Commit 7e1d5e53 authored by zhuwenwen

merge v0.3.1

parents e3378b20 5f08050d
from vllm.logger import init_logger
from aioprometheus import Counter, Gauge, Histogram

import time
import numpy as np
from typing import List
from dataclasses import dataclass

logger = init_logger(__name__)

labels = {}


def add_global_metrics_labels(**kwargs):
    labels.update(kwargs)


# The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions.
@@ -9,12 +24,16 @@ gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s",
gauge_avg_generation_throughput = Gauge(
    "vllm:avg_generation_throughput_toks_per_s",
    "Average generation throughput in tokens/s.")
counter_prompt_tokens = Counter("vllm:prompt_tokens_total",
"Number of prefill tokens processed.")
counter_generation_tokens = Counter("vllm:generation_tokens_total",
"Number of generation tokens processed.")
gauge_scheduler_running = Gauge(
    "vllm:num_requests_running",
    "Number of requests currently running on GPU.")
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped",
                                "Number of requests swapped to CPU.")
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
                                "Number of requests waiting to be processed.")
@@ -24,28 +43,131 @@ gauge_gpu_cache_usage = Gauge(
gauge_cpu_cache_usage = Gauge(
    "vllm:cpu_cache_usage_perc",
    "CPU KV-cache usage. 1 means 100 percent usage.")
histogram_time_to_first_token = Histogram(
"vllm:time_to_first_token_seconds",
"Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0,
2.5, 5.0, 7.5, 10.0
])
histogram_time_per_output_tokens = Histogram(
"vllm:time_per_output_token_seconds",
"Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5
])
histogram_e2e_request_latency = Histogram(
"vllm:e2e_request_latency_seconds",
"Histogram of end to end request latency in seconds.",
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# end-metrics-definitions

@dataclass
class Stats:
    """Created by LLMEngine for use by StatLogger."""
    now: float

    # System stats.
    num_running: int
    num_waiting: int
    num_swapped: int
    gpu_cache_usage: float
    cpu_cache_usage: float

    # Raw stats from last model iteration.
    num_prompt_tokens: int
    num_generation_tokens: int
    time_to_first_tokens: List[float]
    time_per_output_tokens: List[float]
    time_e2e_requests: List[float]
class StatLogger:
    """StatLogger is used by LLMEngine to log to Prometheus and Stdout."""
def __init__(self, local_interval: float) -> None:
# Metadata for logging locally.
self.last_local_log = time.monotonic()
self.local_interval = local_interval
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))
def _local_interval_elapsed(self, now: float) -> bool:
elapsed_time = now - self.last_local_log
return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges.
gauge_scheduler_running.set(labels, stats.num_running)
gauge_scheduler_swapped.set(labels, stats.num_swapped)
gauge_scheduler_waiting.set(labels, stats.num_waiting)
gauge_gpu_cache_usage.set(labels, stats.gpu_cache_usage)
gauge_cpu_cache_usage.set(labels, stats.cpu_cache_usage)
# Add to token counters.
counter_prompt_tokens.add(labels, stats.num_prompt_tokens)
counter_generation_tokens.add(labels, stats.num_generation_tokens)
# Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens:
histogram_time_to_first_token.observe(labels, ttft)
for tpot in stats.time_per_output_tokens:
histogram_time_per_output_tokens.observe(labels, tpot)
for e2e in stats.time_e2e_requests:
histogram_e2e_request_latency.observe(labels, e2e)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side.
# Moving forward, we should use counters like counter_prompt_tokens and counter_generation_tokens,
# which log raw data and let summaries be computed with rate() on the Grafana/Prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
gauge_avg_prompt_throughput.set(labels, prompt_throughput)
gauge_avg_generation_throughput.set(labels, generation_throughput)
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus.
self._log_prometheus(stats)
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens)
self.num_generation_tokens.append(stats.num_generation_tokens)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them to Prometheus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now=stats.now)
self._log_prometheus_interval(
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
f"Avg generation throughput: {generation_throughput:.1f} tokens/s, "
f"Running: {stats.num_running} reqs, "
f"Swapped: {stats.num_swapped} reqs, "
f"Pending: {stats.num_waiting} reqs, "
f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")
            # Reset tracked stats for next interval.
            self.num_prompt_tokens = []
            self.num_generation_tokens = []
            self.last_local_log = stats.now


def record_metrics(
    avg_prompt_throughput: float,
    avg_generation_throughput: float,
    scheduler_running: int,
    scheduler_swapped: int,
    scheduler_waiting: int,
    gpu_cache_usage: float,
    cpu_cache_usage: float,
):
    gauge_avg_prompt_throughput.set(labels, avg_prompt_throughput)
    gauge_avg_generation_throughput.set(labels, avg_generation_throughput)
    gauge_scheduler_running.set(labels, scheduler_running)
    gauge_scheduler_swapped.set(labels, scheduler_swapped)
    gauge_scheduler_waiting.set(labels, scheduler_waiting)
    gauge_gpu_cache_usage.set(labels, gpu_cache_usage)
    gauge_cpu_cache_usage.set(labels, cpu_cache_usage)
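For orientation, a minimal sketch of how an engine loop could drive the new StatLogger; it assumes the module is importable as vllm.engine.metrics, and all Stats values below are made-up numbers, so treat it as illustrative rather than code from this merge.

import time
from vllm.engine.metrics import StatLogger, Stats, add_global_metrics_labels

add_global_metrics_labels(model_name="my-model")  # hypothetical label
stat_logger = StatLogger(local_interval=5.0)      # print to stdout every 5 s

stats = Stats(
    now=time.monotonic(),
    num_running=2, num_waiting=0, num_swapped=0,
    gpu_cache_usage=0.42, cpu_cache_usage=0.0,
    num_prompt_tokens=128, num_generation_tokens=16,
    time_to_first_tokens=[0.031, 0.027],
    time_per_output_tokens=[0.012] * 16,
    time_e2e_requests=[],
)
stat_logger.log(stats)  # updates Prometheus metrics; logs locally once the interval elapses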
import pickle
from typing import Optional, List, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ try:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
# Since the compiled DAG runs the main execution
# in a different thread that calls cuda.set_device,
# this flag indicates whether set_device has been
# called on that thread.
self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
@@ -40,6 +47,17 @@ try:
def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)
def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model()
output = pickle.dumps(output)
return output
except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
......
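A rough sketch of how a driver could consume the pickled result of execute_model_compiled_dag_remote; the worker_actor handle is hypothetical and not part of this diff.

import pickle
import ray

# Assume worker_actor is a Ray actor handle wrapping the worker class above.
output_bytes = ray.get(
    worker_actor.execute_model_compiled_dag_remote.remote(None))
sampler_output = pickle.loads(output_bytes)  # undo the worker-side pickle.dumps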
@@ -111,13 +111,13 @@ class LLM:
    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer.tokenizer = tokenizer

    def generate(
        self,
......
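The extra .tokenizer hop above reflects that llm_engine.tokenizer is now a wrapper (a tokenizer group) around the underlying Hugging Face tokenizer. A small usage sketch, assuming the public LLM API is otherwise unchanged:

from vllm import LLM

llm = LLM(model="facebook/opt-125m")

tokenizer = llm.get_tokenizer()          # returns the underlying HF tokenizer
print(tokenizer.encode("Hello, world"))

llm.set_tokenizer(tokenizer)             # replacing it goes through the same wrapper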
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple

from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -19,8 +19,8 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
@@ -29,7 +29,7 @@ async def completion_stream_generator(
    request: CompletionRequest,
    raw_request: Request,
    on_abort,
    result_generator: AsyncIterator[Tuple[int, RequestOutput]],
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
    created_time: int,
@@ -126,7 +126,7 @@ async def completion_stream_generator(
    yield "data: [DONE]\n\n"


def parse_prompt_format(prompt) -> Tuple[bool, list]:
    # get the prompt, openai supports the following
    # "a string, array of strings, array of tokens, or array of token arrays."
    prompt_is_tokens = False
@@ -151,7 +151,7 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
def request_output_to_completion_response(
    final_res_batch: List[RequestOutput],
    request: CompletionRequest,
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
@@ -302,7 +302,7 @@ class OpenAIServingCompletion(OpenAIServing):
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
......
@@ -3,6 +3,7 @@
"""Logging configuration for vLLM."""
import logging
import sys
import os

_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
@@ -50,7 +51,7 @@ _setup_logger()
def init_logger(name: str):
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
    logger.addHandler(_default_handler)
    logger.propagate = False
    return logger
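With this change the log level can be lowered without touching code; a quick illustration (only the LOG_LEVEL variable above is assumed, the level names come from the standard logging module):

import os

# Must be set before vllm modules are imported, since init_logger reads it then.
os.environ["LOG_LEVEL"] = "INFO"

from vllm.logger import init_logger

logger = init_logger(__name__)
logger.debug("suppressed at INFO level")
logger.info("still shown")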
@@ -4,8 +4,7 @@ import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)

import safetensors.torch
import torch
@@ -20,36 +19,6 @@ from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule

logger = logging.getLogger(__name__)
# TODO: The mappings below should be moved to individual model classes.
PACKED_MODULES_CFG = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
TARGET_MODULES_QKV = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
_GLOBAL_LORA_ID = 0
@@ -169,6 +138,8 @@ class LoRAModel:
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and not in_wsl()
@@ -179,11 +150,11 @@ class LoRAModel:
                lora_embeddings_tensor = None
                if embeddings:
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
@@ -201,7 +172,7 @@ class LoRAModel:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
@@ -218,12 +189,15 @@ class LoRAModel:
    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
@@ -260,6 +234,8 @@ class LoRAModel:
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
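A sketch of how a caller now passes the per-model embedding metadata when loading an adapter from disk; the path and numbers are placeholders, and the module-name mapping simply mirrors the constants removed from this file, so treat it as illustrative.

import torch
from vllm.lora.models import LoRAModel

lora = LoRAModel.from_local_checkpoint(
    "/path/to/adapter",                    # directory containing adapter_config.json
    lora_model_id=1,
    device="cuda",
    dtype=torch.float16,
    target_embedding_padding=32000 + 256,  # vocab size + extra LoRA vocab (example)
    embedding_modules={"embed_tokens": "input_embeddings",
                       "lm_head": "output_embeddings"},
    embedding_padding_modules=["lm_head"],
)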
@@ -273,8 +249,6 @@ class LoRAModelManager:
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        """Create a LoRAModelManager and adapter for a given model.
@@ -286,13 +260,6 @@ class LoRAModelManager:
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            lora_target_modules: the target modules patterns to be adapted.
                Support both single module name and a list of module names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., qkv_proj
                is packed of q_proj, k_proj, and v_proj. These modules
                have a single layer in the original model, but they are split
                into multiple layers in the adapted model.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
@@ -320,11 +287,11 @@ class LoRAModelManager:
        self.indices_len = [None] * 4

        self.model: nn.Module = model
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
@@ -468,7 +435,11 @@ class LoRAModelManager:
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
@@ -477,7 +448,7 @@ class LoRAModelManager:
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
@@ -531,7 +502,7 @@ class LoRAModelManager:
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)
    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
@@ -586,12 +557,9 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
@@ -637,11 +605,10 @@ def create_lora_manager(
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
    **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
@@ -649,6 +616,5 @@ def create_lora_manager(
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
lora_target_modules=target_modules,
        **kwargs)
    return lora_manager
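Since the shared TARGET_MODULES_QKV / PACKED_MODULES_CFG defaults are gone, create_lora_manager now expects the model class itself to carry these mappings. A minimal sketch of the attributes it looks for, with example values loosely following the removed constants:

import torch.nn as nn

class MyLoRACapableModel(nn.Module):
    # Attributes consulted by the LoRA manager; the values here are examples only.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]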
@@ -4,173 +4,167 @@ from typing import Optional

import torch


def _raise_import_error(e):
    if torch.cuda.get_device_capability() < (8, 0):
        raise ImportError(
            "punica LoRA kernels require compute capability >= 8.0") from e
    else:
        raise ImportError(
            "punica LoRA kernels could not be imported. If you built vLLM "
            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
            "was set.") from e


def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
        matrices.
      indicies: Shape: `[B]`. Indices of the weight matrices.
      layer_idx: Layer index of the weight matrices.
      scale: Scaling factor.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                 scale)


def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """
    Same as `add_lora` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indicies,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    punica_kernels.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indicies,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
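For orientation, a sketch of the tensor shapes add_lora expects; it only runs on a GPU of compute capability >= 8.0 with the punica extension built, so the numbers are purely illustrative.

import torch
from vllm.lora.punica import add_lora

B, H1, H2, R, L = 4, 1024, 1024, 16, 32   # batch, in/out widths, LoRA rank, layers
y = torch.zeros(B, H2, dtype=torch.float16, device="cuda")
x = torch.randn(B, H1, dtype=torch.float16, device="cuda")
wa_t_all = torch.randn(8, L, R, H1, dtype=torch.float16, device="cuda")  # 8 adapters
wb_t_all = torch.randn(8, L, H2, R, dtype=torch.float16, device="cuda")
indicies = torch.zeros(B, dtype=torch.long, device="cuda")  # adapter index per row

add_lora(y, x, wa_t_all, wb_t_all, indicies, layer_idx=0, scale=1.0)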
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type

import torch

from vllm.lora.models import (LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
@@ -13,7 +13,7 @@ from vllm.config import LoRAConfig

logger = logging.getLogger(__name__)


class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
@@ -33,7 +33,6 @@ class WorkerLoRAManager(ABC):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        ...
@@ -63,7 +62,7 @@ class WorkerLoRAManager(ABC):
        ...


class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
@@ -78,10 +77,14 @@ class WorkerLoRAManager(WorkerLoRAManager):
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_manager: Optional[LoRAModelManager] = None
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)
@@ -92,13 +95,11 @@ class WorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
@@ -142,6 +143,8 @@ class WorkerLoRAManager(WorkerLoRAManager):
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
@@ -162,7 +165,7 @@ class WorkerLoRAManager(WorkerLoRAManager):
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
@@ -195,11 +198,9 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
......
@@ -89,9 +89,7 @@ class ScaledActivation(nn.Module):
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
......
"""Multi-head attention."""
from typing import List, Optional
import importlib

import torch
import torch.nn as nn
from xformers import ops as xops
@@ -58,6 +59,40 @@ class PagedAttention(nn.Module):
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
self.use_ref_attention = self.check_use_ref_attention()
def check_use_ref_attention(self) -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
def ref_masked_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
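The reference path above is plain masked attention in PyTorch; below is a self-contained, CPU-runnable sketch of the same computation (names and shapes are illustrative, not the class's API).

import torch

def ref_masked_attention(query, key, value, scale, num_heads, head_size):
    # query/key/value: [T, num_heads * head_size]; causal mask over T positions.
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_heads, head_size)
    value = value.view(-1, num_heads, head_size)
    seq_len = query.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype), diagonal=1)
    mask = mask * torch.finfo(query.dtype).min
    attn = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn = torch.softmax(attn + mask.float(), dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn, value)

out = ref_masked_attention(torch.randn(8, 64), torch.randn(8, 64),
                           torch.randn(8, 64), scale=0.125,
                           num_heads=4, head_size=16)
print(out.shape)  # torch.Size([8, 4, 16])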
    def forward(
        self,
        query: torch.Tensor,
@@ -137,6 +172,16 @@ class PagedAttention(nn.Module):
                    self.alibi_slopes, self.num_kv_heads, batch_size,
                    seq_len, query.dtype)
if self.use_ref_attention:
output = self.ref_masked_attention(
query,
key,
value,
)
# Using .view() raised "RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces)", so use .reshape() instead.
return output.reshape(batch_size, seq_len, hidden_size)
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
@@ -200,7 +245,7 @@ def _make_alibi_bias(
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
......
@@ -4,6 +4,7 @@ import triton
import triton.language as tl

from vllm._C import ops
from vllm.utils import is_hip
@triton.jit
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int, config: dict):
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1
@@ -210,12 +210,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
    )


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
@@ -223,22 +226,59 @@
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place. Defaults to False.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
M, _ = hidden_states.shape M, _ = hidden_states.shape
E, N, _ = w1.shape E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    config = {
        'BLOCK_SIZE_M': 64,
        'BLOCK_SIZE_N': 64,
......
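A sketch of the new call signature, where the raw router logits go in and the top-k selection plus optional renormalization happen inside fused_moe; shapes are illustrative and running it needs a CUDA device with the Triton kernels (the import path is assumed to stay vllm.model_executor.layers.fused_moe).

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

M, K, N, E, topk = 32, 4096, 14336, 8, 2     # tokens, hidden, ffn, experts, top-k
hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.float16, device="cuda")  # gate+up packed
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")  # router logits

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
print(out.shape)  # (M, K)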
@@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
@@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
            self.register_parameter(name, weight)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)
@@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
@@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
......
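Dropping device="cuda" in the parameter constructors above means the tensors are allocated on whatever the default device is at construction time, with final placement presumably decided by the caller (for example a loader or a default-device context). A small generic-PyTorch illustration of that mechanism, not vLLM code; the last two lines need a CUDA build to run.

import torch
import torch.nn as nn

# Without an explicit device, empty tensors land on the current default device.
w_cpu = nn.Parameter(torch.empty(8, 8, dtype=torch.float16), requires_grad=False)
print(w_cpu.device)  # cpu

# A default-device context changes where subsequent allocations go.
torch.set_default_device("cuda")
w_gpu = nn.Parameter(torch.empty(8, 8, dtype=torch.float16), requires_grad=False)
print(w_gpu.device)  # cuda:0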
@@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
@@ -148,8 +145,8 @@ class AWQLinearMethod(LinearMethodBase):
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        scales = weights["scales"]
        qzeros = weights["qzeros"]
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
......
@@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
......
@@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
            torch.empty(
                output_size,
                self.quant_config.weight_bits**2,
                dtype=params_dtype,
            ),
            requires_grad=False,
@@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        if is_hip():
            out_f = torch.zeros(out_shape, dtype=torch.float)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
......
...@@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module): ...@@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
# create the cache on GPU for faster initialization. This may cause # create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours. # a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange( inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
self.rotary_dim))
return inv_freq return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache.""" """Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base) inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, t = torch.arange(self.max_position_embeddings, dtype=torch.float)
dtype=torch.float,
device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() cos = freqs.cos()
...@@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
# Thus, the maximum length after applying the rope scaling is # Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor. # self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda") t = torch.arange(max_len, dtype=torch.float)
t = t / self.scaling_factor t = t / self.scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
...@@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
(self.scaling_factor - 1))**(self.rotary_dim / (self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2)) (self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base) inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda") t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() cos = freqs.cos()
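Note: a sketch, assuming the usual dynamic-NTK formulation visible in this hunk, of how the RoPE base is inflated before inv_freq is recomputed for a longer max_len; the helper is illustrative.

```python
def dynamic_ntk_base(base: float, scaling_factor: float, rotary_dim: int,
                     max_len: int, max_position_embeddings: int) -> float:
    # Inflate the RoPE base so the recomputed inv_freq spans max_len positions.
    return base * ((scaling_factor * max_len / max_position_embeddings) -
                   (scaling_factor - 1))**(rotary_dim / (rotary_dim - 2))
```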
...@@ -297,9 +294,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -297,9 +294,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style) is_neox_style)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange( pos_freqs = self.base**(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim) self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
...@@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.max_position_embeddings) self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation # Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask( inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float, low, high, self.rotary_dim // 2,
device="cuda")) * self.extrapolation_factor dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * ( inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq return inv_freq
...@@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor) inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor, t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32) dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale) cos = (freqs.cos() * self.mscale)
......
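Note: a hedged sketch of the YaRN blend these hunks show: interpolated and extrapolated inverse frequencies are mixed by a linear ramp mask. The ramp below is a simplified stand-in for _yarn_linear_ramp_mask, not the library's implementation.

```python
import torch

def yarn_inv_freq(base: float, rotary_dim: int, scaling_factor: float,
                  low: float, high: float,
                  extrapolation_factor: float = 1.0) -> torch.Tensor:
    pos_freqs = base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
    # Simplified ramp: 0 below `low`, 1 above `high`, linear in between.
    ramp = (torch.arange(rotary_dim // 2, dtype=torch.float) - low) / max(high - low, 1e-3)
    inv_freq_mask = (1 - ramp.clamp(0, 1)) * extrapolation_factor
    return (inv_freq_interpolation * (1 - inv_freq_mask) +
            inv_freq_extrapolation * inv_freq_mask)
```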
...@@ -618,7 +618,9 @@ if triton.__version__ >= "2.1.0": ...@@ -618,7 +618,9 @@ if triton.__version__ >= "2.1.0":
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
alibi_slopes=None): alibi_slopes=None):
BLOCK = 128
cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
......
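Note: the capability check added above, restated as a small standalone helper. The motivation is an assumption: pre-Ampere GPUs (compute capability below 8) have less shared memory per SM, so a smaller Triton block size is safer.

```python
import torch

def pick_block_size(default: int = 128, fallback: int = 64) -> int:
    major, _minor = torch.cuda.get_device_capability()
    return default if major >= 8 else fallback
```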
...@@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter( self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
self.embedding_dim, self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
set_weight_attrs(self.weight, { set_weight_attrs(self.weight, {
"parallel_dim": 0, "parallel_dim": 0,
...@@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
set_weight_attrs(self.bias, { set_weight_attrs(self.bias, {
"parallel_dim": 0, "parallel_dim": 0,
......
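Note: a rough sketch (not vLLM's class) of how a vocab-parallel embedding shards its weight: each tensor-parallel rank owns a slice of the vocabulary rows, allocated without an explicit device= argument. The rounding scheme is assumed for illustration.

```python
import torch

def make_partitioned_embedding(num_embeddings: int, embedding_dim: int,
                               tp_size: int,
                               params_dtype: torch.dtype = torch.float16
                               ) -> torch.nn.Parameter:
    # Assumed sharding: an equal, rounded-up slice of rows per rank.
    per_partition = (num_embeddings + tp_size - 1) // tp_size
    return torch.nn.Parameter(
        torch.empty(per_partition, embedding_dim, dtype=params_dtype))
```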
...@@ -5,7 +5,7 @@ from typing import Optional, Type ...@@ -5,7 +5,7 @@ from typing import Optional, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
...@@ -38,16 +38,14 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: ...@@ -38,16 +38,14 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig, def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module: lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config) model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
linear_method = None linear_method = None
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config.quantization, quant_config = get_quant_config(model_config)
model_config.model,
model_config.hf_config,
model_config.download_dir)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability(): if capability < quant_config.get_min_capability():
...@@ -67,8 +65,8 @@ def get_model(model_config: ModelConfig, ...@@ -67,8 +65,8 @@ def get_model(model_config: ModelConfig,
with _set_default_torch_dtype(model_config.dtype): with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
with torch.device("cuda"): with torch.device(device_config.device):
if getattr(model_class, "supports_lora", False): if hasattr(model_class, "supported_lora_modules"):
model = model_class(model_config.hf_config, linear_method, model = model_class(model_config.hf_config, linear_method,
lora_config) lora_config)
elif lora_config: elif lora_config:
......
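Note: a hedged sketch of the loader flow after this change: the target device now comes from a DeviceConfig instead of a hard-coded "cuda" string, and the model is instantiated inside that device context. Signatures are simplified and the helper name is illustrative.

```python
import torch

def load_model_skeleton(model_class, hf_config, linear_method,
                        device: str = "cuda"):
    # Weights start as empty tensors on `device`; the real checkpoint is
    # loaded afterwards by the weight loader.
    with torch.device(device):
        return model_class(hf_config, linear_method)
```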
...@@ -10,8 +10,8 @@ logger = init_logger(__name__) ...@@ -10,8 +10,8 @@ logger = init_logger(__name__)
# Architecture -> (module, class). # Architecture -> (module, class).
_MODELS = { _MODELS = {
"AquilaModel": ("aquila", "AquilaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
...@@ -24,7 +24,8 @@ _MODELS = { ...@@ -24,7 +24,8 @@ _MODELS = {
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
...@@ -40,7 +41,6 @@ _MODELS = { ...@@ -40,7 +41,6 @@ _MODELS = {
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM")
} }
# Models not supported by ROCm. # Models not supported by ROCm.
......
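Note: an illustrative sketch of how an architecture string resolves to a class via the (module, class) table above, e.g. Aquila and InternLM now reusing the Llama implementation. The importlib path is an assumption about the registry's lazy-import behaviour, not copied from this diff.

```python
import importlib

_MODELS = {
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
}

def resolve_model_cls(architecture: str):
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)
```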