Commit 35393439 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge dtk24.04-v0.3.3

parents 7c4f76e3 f26ecef8
...@@ -11,7 +11,7 @@ prompts = [ ...@@ -11,7 +11,7 @@ prompts = [
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m") llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -5,7 +5,7 @@ starlette ...@@ -5,7 +5,7 @@ starlette
requests requests
py-cpuinfo py-cpuinfo
psutil psutil
ray == 2.9.3 ray == 2.9.1
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
tokenizers>=0.15.0 tokenizers>=0.15.0
......
...@@ -20,7 +20,8 @@ NUM_BLOCKS = 4321 # Arbitrary values for testing ...@@ -20,7 +20,8 @@ NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16] # ] if not is_hip() else [torch.half, torch.bfloat16]
] if not is_hip() else [torch.half]
NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
...@@ -32,7 +33,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256 ...@@ -32,7 +33,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
......
...@@ -23,7 +23,7 @@ SEEDS = [0] ...@@ -23,7 +23,7 @@ SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
......
...@@ -68,8 +68,8 @@ def test_get_prompt_logprobs( ...@@ -68,8 +68,8 @@ def test_get_prompt_logprobs(
logprob = sample_logprob.logprob logprob = sample_logprob.logprob
torch.testing.assert_close(logprob, torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(), hf_logprob[i][-1][token_id].item(),
atol=1e-2, atol=1e-1,
rtol=1e-2) rtol=1e-1)
assert isinstance(sample_logprob.decoded_token, str), ( assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned " "The token should be decoded by the time it is returned "
" to the user.") " to the user.")
...@@ -84,3 +84,4 @@ def test_max_logprobs(): ...@@ -84,3 +84,4 @@ def test_max_logprobs():
bad_sampling_params = SamplingParams(logprobs=2) bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError): with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params) runner.generate(["Hello world"], sampling_params=bad_sampling_params)
...@@ -14,6 +14,9 @@ from vllm.model_executor.parallel_utils.utils import ( ...@@ -14,6 +14,9 @@ from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim) divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
import os
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -55,6 +58,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -55,6 +58,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False): def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, input_size_per_partition: int, def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int, output_size_per_partition: int, input_size: int,
...@@ -76,7 +80,11 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -76,7 +80,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias is not None: if bias is not None:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
return F.linear(x, weight, bias) if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1)
return torch.matmul(x, weight)
else:
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module): class ReplicatedLinear(torch.nn.Module):
...@@ -195,6 +203,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -195,6 +203,7 @@ class ColumnParallelLinear(torch.nn.Module):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -206,6 +215,9 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -206,6 +215,9 @@ class ColumnParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -259,6 +271,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -259,6 +271,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output, super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method) skip_bias_add, params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -313,8 +326,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -313,8 +326,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, if self.use_llama_nn:
shard_size) param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -325,8 +342,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -325,8 +342,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -385,6 +410,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -385,6 +410,7 @@ class QKVParallelLinear(ColumnParallelLinear):
2 * self.num_kv_heads) * tp_size * self.head_size 2 * self.num_kv_heads) * tp_size * self.head_size
super().__init__(input_size, output_size, bias, False, skip_bias_add, super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method) params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -450,7 +476,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -450,7 +476,11 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = tp_rank
...@@ -466,8 +496,16 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -466,8 +496,16 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v":
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module): class RowParallelLinear(torch.nn.Module):
...@@ -545,6 +583,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -545,6 +583,7 @@ class RowParallelLinear(torch.nn.Module):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -556,6 +595,9 @@ class RowParallelLinear(torch.nn.Module): ...@@ -556,6 +595,9 @@ class RowParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.models import ModelRegistry ...@@ -10,6 +10,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
import os
_VISION_MODEL_CLASSES = [ _VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
...@@ -28,6 +29,8 @@ def _set_default_torch_dtype(dtype: torch.dtype): ...@@ -28,6 +29,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture( def _get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']:
os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
...@@ -171,7 +171,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: ...@@ -171,7 +171,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
return _async_wrapper return _async_wrapper
def get_ip() -> str: def get_ip() -> str:
host_ip = os.environ.get("HOST_IP") host_ip = os.environ.get("HOST_IP")
if host_ip: if host_ip:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment