Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a62aaf1d
Unverified
Commit
a62aaf1d
authored
Apr 26, 2024
by
Cody Yu
Committed by
GitHub
Apr 26, 2024
Browse files
[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
parent
603ad848
Changes
45
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
328 additions
and
299 deletions
+328
-299
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+1
-1
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+2
-2
vllm/lora/layers.py
vllm/lora/layers.py
+13
-17
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+88
-81
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+8
-5
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+11
-8
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+28
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+20
-40
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+11
-8
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+8
-5
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+14
-10
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+18
-20
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+7
-6
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+22
-21
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+17
-16
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+19
-18
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+17
-16
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+18
-17
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+4
-3
No files found.
tests/quantization/test_fp8.py
View file @
a62aaf1d
...
...
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
linear
_method
,
Fp8LinearMethod
)
assert
isinstance
(
fc1
.
quant
_method
,
Fp8LinearMethod
)
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
tests/tensorizer_loader/test_tensorizer.py
View file @
a62aaf1d
...
...
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance
.
deserialize
.
return_value
=
MagicMock
()
result
=
load_with_tensorizer
(
tensorizer_config
,
linear
_method
=
mock_linear_method
)
quant
_method
=
mock_linear_method
)
mock_agent
.
assert_called_once_with
(
tensorizer_config
,
linear
_method
=
mock_linear_method
)
quant
_method
=
mock_linear_method
)
mock_agent_instance
.
deserialize
.
assert_called_once
()
assert
result
==
mock_agent_instance
.
deserialize
.
return_value
...
...
vllm/lora/layers.py
View file @
a62aaf1d
...
...
@@ -389,10 +389,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
indices
=
base_indices
self
.
indices_len
=
indices_len
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
_apply_lora
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -416,7 +415,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if
not
self
.
base_layer
.
skip_bias_add
else
None
)
# Matrix multiply.
output_parallel
=
self
.
apply
_weights
(
input_
,
bias
)
output_parallel
=
self
.
apply
(
input_
,
bias
)
if
self
.
base_layer
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
...
@@ -523,10 +522,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -765,10 +763,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index
,
0
,
:
lora_a
[
2
].
shape
[
1
],
:
lora_a
[
2
].
shape
[
0
]].
copy_
(
lora_a
[
2
].
T
,
non_blocking
=
True
)
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -862,9 +859,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
indices
=
base_indices
self
.
indices_len
=
indices_len
def
apply_weights
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
,
x
)
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
_apply_lora
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -897,7 +893,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
output_parallel
=
self
.
apply
_weights
(
input_parallel
)
output_parallel
=
self
.
apply
(
input_parallel
)
if
self
.
base_layer
.
reduce_results
and
self
.
base_layer
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
...
...
vllm/model_executor/layers/linear.py
View file @
a62aaf1d
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
...
@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
...
...
@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
class
LinearMethodBase
(
ABC
):
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
...
...
@@ -50,7 +51,7 @@ class LinearMethodBase(ABC):
raise
NotImplementedError
@
abstractmethod
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -59,13 +60,6 @@ class LinearMethodBase(ABC):
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class
UnquantizedLinearMethod
(
LinearMethodBase
):
"""Linear method without quantization.
...
...
@@ -92,7 +86,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
F
.
linear
(
x
,
weight
,
bias
)
class
Replicated
Linear
(
torch
.
nn
.
Module
):
"""
Replicated
linear layer.
class
Linear
Base
(
torch
.
nn
.
Module
):
"""
Base
linear layer.
Args:
input_size: input dimension of the linear layer.
...
...
@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configuration
.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module):
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
if
quant_config
is
None
:
self
.
quant_method
=
UnquantizedLinearMethod
()
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
...
...
@@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output
=
self
.
linear
_method
.
apply
_weights
(
self
,
x
,
bias
)
output
=
self
.
quant
_method
.
apply
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
class
ColumnParallelLinear
(
torch
.
nn
.
Modul
e
):
class
ColumnParallelLinear
(
LinearBas
e
):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
...
...
@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configuration
.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
...
...
@@ -184,28 +208,20 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
):
super
().
__init__
()
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
# Keep input parameters
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
self
.
linear_method
=
linear_method
self
.
linear_method
.
create_weights
(
self
,
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
x
//
tp_size
for
x
in
output_sizes
],
self
.
input_size
,
...
...
@@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
output_parallel
=
self
.
linear
_method
.
apply
_weights
(
self
,
input_
,
bias
)
output_parallel
=
self
.
quant
_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
...
@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configuration
.
"""
def
__init__
(
...
...
@@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
skip_bias_add
,
params_dtype
,
linear_method
,
skip_bias_add
,
params_dtype
,
quant_config
,
self
.
output_sizes
)
def
weight_loader
(
self
,
...
...
@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configuration
.
"""
def
__init__
(
...
...
@@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
...
...
@@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear):
]
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
params_dtype
,
linear_method
,
output_sizes
)
params_dtype
,
quant_config
,
output_sizes
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
torch
.
nn
.
Modul
e
):
class
RowParallelLinear
(
LinearBas
e
):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
...
...
@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configuration
.
"""
def
__init__
(
...
...
@@ -552,26 +568,18 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# Keep input parameters
self
.
input_size
=
input_size
self
.
output_size
=
output_size
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
skip_bias_add
=
skip_bias_add
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_method
.
create_weights
(
self
,
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size_per_partition
,
[
self
.
output_size
],
self
.
input_size
,
...
...
@@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module):
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
,
input_parallel
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
a62aaf1d
...
...
@@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
F
P
8Config
from
vllm.model_executor.layers.quantization.fp8
import
F
p
8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
...
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"fp8"
:
F
P
8Config
,
"fp8"
:
F
p
8Config
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"marlin"
:
MarlinConfig
,
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
a62aaf1d
...
...
@@ -9,10 +9,10 @@ import torch.nn.functional as F
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
def
get_int_dtype
(
nbits
:
int
)
->
torch
.
dtype
:
...
...
@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
def
get_linear_method
(
self
)
->
"AQLMLinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AQLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AQLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply
_weights
(
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
a62aaf1d
...
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
AWQConfig
(
QuantizationConfig
):
...
...
@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
return
cls
(
weight_bits
,
group_size
,
zero_point
)
def
get_linear_method
(
self
)
->
"AWQLinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AWQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AWQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[
"gelu"
,
"gelu_fast"
,
"gelu_new"
,
"gelu_pytorch_tanh"
]
...
...
@@ -147,7 +150,7 @@ class AWQLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
a62aaf1d
...
...
@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
from
typing
import
Any
,
Dict
,
List
import
torch
from
torch
import
nn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
class
QuantizeMethodBase
(
ABC
):
"""Base class for different quantized methods."""
@
abstractmethod
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
*
weight_args
,
**
extra_weight_attrs
):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class
QuantizationConfig
(
ABC
):
...
...
@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
"quantization config."
)
@
abstractmethod
def
get_
linear
_method
(
self
)
->
Linear
MethodBase
:
"""Get the
linear
method to use for the quantized
linear
layer."""
def
get_
quant
_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Quantize
MethodBase
:
"""Get the
quantize
method to use for the quantized layer."""
raise
NotImplementedError
@
abstractmethod
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
a62aaf1d
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm
.model_exe
cuto
r.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm
import
_
cu
s
to
m_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
F
P
8Config
(
QuantizationConfig
):
class
F
p
8Config
(
QuantizationConfig
):
"""Config class for FP8."""
@
classmethod
...
...
@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"F
P
8Config"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"F
p
8Config"
:
return
cls
()
def
get_linear_method
(
self
)
->
"Fp8LinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
F
P
8Config
):
def
__init__
(
self
,
quant_config
:
F
p
8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
...
...
@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"weight_scaling_factor"
,
w_scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Although the
linear
_method is propagated to all layers,
# Although the
quant
_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_factor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if
not
hasattr
(
layer
,
"weight_scaling_factor"
):
return
qweight
,
weight_scale
=
per_tensor
_quant
ize
(
layer
.
weight
)
qweight
,
weight_scale
=
ops
.
scaled_fp8
_quant
(
layer
.
weight
)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scaling_factor
.
data
.
copy_
(
weight_scale
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qinput
,
x_scale
=
per_tensor
_quant
ize
(
x
)
qinput
,
x_scale
=
ops
.
scaled_fp8
_quant
(
x
)
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
...
...
@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
bias
=
bias
,
)
return
output
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
float
]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val
,
max_val
=
tensor
.
aminmax
()
amax
=
min_val
.
abs
().
max
(
max_val
.
abs
())
scale
=
finfo
.
max
/
amax
.
clamp
(
min
=
1e-12
)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight
=
(
tensor
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight
=
qweight
.
to
(
torch
.
float8_e4m3fn
)
scale
=
scale
.
float
().
reciprocal
()
return
qweight
,
scale
vllm/model_executor/layers/quantization/gptq.py
View file @
a62aaf1d
...
...
@@ -7,10 +7,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
GPTQConfig
(
QuantizationConfig
):
...
...
@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
)
def
get_linear_method
(
self
)
->
"GPTQLinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -194,7 +197,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
exllama_state
=
exllama_state
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/marlin.py
View file @
a62aaf1d
...
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
MarlinConfig
(
QuantizationConfig
):
...
...
@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
group_size
)
def
get_linear_method
(
self
)
->
"MarlinLinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
MarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply
_weights
(
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
a62aaf1d
...
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
...
...
@@ -51,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
def
get_linear_method
(
self
)
->
"SqueezeLLMLinearMethod"
:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"SqueezeLLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SqueezeLLMLinearMethod
(
Linear
MethodBase
):
class
SqueezeLLMLinearMethod
(
Quantize
MethodBase
):
"""Linear method for SqueezeLLM.
Args:
...
...
@@ -112,7 +116,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/model_loader/loader.py
View file @
a62aaf1d
...
...
@@ -3,8 +3,7 @@ import copy
import
glob
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
)
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
torch
import
nn
...
...
@@ -13,6 +12,8 @@ from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
tensorizer_weights_iterator
)
...
...
@@ -24,9 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.linear
import
LinearMethodBase
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
...
...
@@ -34,11 +32,10 @@ _VISION_MODEL_CLASSES = [
logger
=
init_logger
(
__name__
)
def
_get_
linear_method
(
def
_get_
quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
"LinearMethodBase"
]:
"""Get the (maybe quantized) linear method."""
linear_method
=
None
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
...
...
@@ -55,9 +52,8 @@ def _get_linear_method(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
linear_method
=
quant_config
.
get_linear_method
()
return
linear_method
return
quant_config
return
None
def
_get_model_initialization_kwargs
(
...
...
@@ -85,10 +81,10 @@ def _initialize_model(
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model_config
,
load_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
))
...
...
@@ -229,9 +225,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load"
,
True
)),
)
for
_
,
module
in
model
.
named_modules
():
linear_method
=
getattr
(
module
,
"linear_method"
,
None
)
if
linear_method
is
not
None
:
linear_method
.
process_weights_after_loading
(
module
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if
hasattr
(
module
,
"process_weights_after_loading"
):
module
.
process_weights_after_loading
()
return
model
.
eval
()
...
...
@@ -314,11 +312,11 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model
_config
,
self
.
load_config
)
quant_config
=
_get_quantization
_config
(
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
)
extra_kwargs
[
"
linear_method"
]
=
linear_method
extra_kwargs
[
"
quant_config"
]
=
quant_config
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
a62aaf1d
...
...
@@ -13,7 +13,8 @@ from transformers import PretrainedConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -251,7 +252,7 @@ class TensorizerAgent:
"""
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
linear_method
:
LinearMethodBase
,
**
extra_kwargs
):
quant_config
:
QuantizationConfig
,
**
extra_kwargs
):
if
tensorizer_load_fail
is
not
None
:
raise
ImportError
(
"Tensorizer is not installed. Please install tensorizer "
...
...
@@ -262,10 +263,10 @@ class TensorizerAgent:
self
.
tensorizer_args
=
(
self
.
tensorizer_config
.
_construct_tensorizer_args
())
self
.
extra_kwargs
=
extra_kwargs
if
extra_kwargs
.
get
(
"
linear_method
"
,
None
)
is
not
None
:
self
.
linear_method
=
extra_kwargs
[
"
linear_method
"
]
if
extra_kwargs
.
get
(
"
quant_config
"
,
None
)
is
not
None
:
self
.
quant_config
=
extra_kwargs
[
"
quant_config
"
]
else
:
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
self
.
_init_model
()
def
_init_model
(
self
):
...
...
@@ -274,7 +275,7 @@ class TensorizerAgent:
with
no_init_or_tensor
():
return
self
.
tensorizer_config
.
model_class
(
config
=
model_args
,
linear_method
=
self
.
linear_method
,
quant_config
=
self
.
quant_config
,
**
self
.
extra_kwargs
)
def
_resize_lora_embeddings
(
self
):
...
...
vllm/model_executor/models/baichuan.py
View file @
a62aaf1d
...
...
@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding
:
str
,
rope_theta
:
float
=
10000
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Create the alibi slopes and slice them.
if
self
.
postion_embedding
==
"ALIBI"
:
...
...
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
...
...
@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding
=
position_embedding
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
mlp
=
BaiChuanMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
BaiChuanDecoderLayer
(
config
,
position_embedding
,
linear_method
)
BaiChuanDecoderLayer
(
config
,
position_embedding
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
,
config
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
config
,
"ALIBI"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ALIBI"
,
quant_config
,
lora_config
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
vllm/model_executor/models/bloom.py
View file @
a62aaf1d
...
...
@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Create the alibi slopes and slice them.
...
...
@@ -129,21 +130,21 @@ class BloomMLP(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
hidden_size
,
4
*
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
gelu_impl
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
4
*
hidden_size
,
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -158,17 +159,17 @@ class BloomBlock(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
self_attention
=
BloomAttention
(
config
,
linear_method
)
self
.
self_attention
=
BloomAttention
(
config
,
quant_config
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
BloomMLP
(
config
,
linear_method
)
self
.
mlp
=
BloomMLP
(
config
,
quant_config
)
self
.
apply_residual_connection_post_layernorm
=
(
config
.
apply_residual_connection_post_layernorm
)
...
...
@@ -214,7 +215,7 @@ class BloomModel(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -229,7 +230,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
BloomBlock
(
config
,
linear_method
)
BloomBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
a62aaf1d
...
...
@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...
...
@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config
.
hidden_size
,
[
config
.
ffn_hidden_size
]
*
2
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
activation_func
=
SiluAndMul
()
...
...
@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config
.
ffn_hidden_size
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
hidden_states
):
...
...
@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
apply_residual_connection_post_layernorm
=
(
...
...
@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps
=
config
.
layernorm_epsilon
)
# Self attention.
self
.
self_attention
=
GLMAttention
(
config
,
linear_method
)
self
.
self_attention
=
GLMAttention
(
config
,
quant_config
)
self
.
hidden_dropout
=
config
.
hidden_dropout
# Layernorm on the attention output
...
...
@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config
.
hidden_size
,
eps
=
config
.
layernorm_epsilon
)
# MLP
self
.
mlp
=
GLMMLP
(
config
,
linear_method
)
self
.
mlp
=
GLMMLP
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
post_layer_norm
=
config
.
post_layer_norm
...
...
@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self
.
layers
=
nn
.
ModuleList
(
[
GLMBlock
(
config
,
linear_method
)
for
i
in
range
(
self
.
num_layers
)])
[
GLMBlock
(
config
,
quant_config
)
for
i
in
range
(
self
.
num_layers
)])
if
self
.
post_layer_norm
:
layer_norm_func
=
RMSNorm
if
config
.
rmsnorm
else
LayerNorm
...
...
@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self
.
num_layers
=
config
.
num_layers
self
.
multi_query_group_num
=
config
.
multi_query_group_num
self
.
kv_channels
=
config
.
kv_channels
self
.
encoder
=
GLMTransformer
(
config
,
linear_method
)
self
.
encoder
=
GLMTransformer
(
config
,
quant_config
)
self
.
output_layer
=
ParallelLMHead
(
config
.
padded_vocab_size
,
config
.
hidden_size
)
...
...
@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
ChatGLMModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
a62aaf1d
...
...
@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
act_fn
=
SiluAndMul
()
...
...
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
CohereAttention
(
config
,
linear_method
=
linear_method
)
self
.
self_attn
=
CohereAttention
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
CohereMLP
(
config
,
linear_method
=
linear_method
)
self
.
mlp
=
CohereMLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
eps
=
config
.
layer_norm_eps
)
...
...
@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
CohereDecoderLayer
(
config
,
linear_method
=
linear_method
)
CohereDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
...
...
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
linear_method
)
self
.
model
=
CohereModel
(
config
,
quant_config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
...
...
vllm/model_executor/models/dbrx.py
View file @
a62aaf1d
...
...
@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
params_dtype
,
linear_method
=
None
,
quant_config
=
None
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
super
().
__init__
()
...
...
@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
d_model
,
self
.
d_model
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
attn
=
DbrxAttention
(
config
,
linear_method
)
self
.
attn
=
DbrxAttention
(
config
,
quant_config
)
self
.
norm_1
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
norm_2
=
nn
.
LayerNorm
(
self
.
d_model
)
...
...
@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
linear_method
)
self
.
ffn
=
DbrxExperts
(
config
,
linear_method
)
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
quant_config
)
self
.
ffn
=
DbrxExperts
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
wte
=
VocabParallelEmbedding
(
...
...
@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config
.
d_model
,
)
self
.
blocks
=
nn
.
ModuleList
(
[
DbrxBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layers
)])
[
DbrxBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
,
eps
=
1e-5
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"bias"
)
and
isinstance
(
module
.
bias
,
...
...
@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
linear_method
)
self
.
transformer
=
DbrxModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
a62aaf1d
...
...
@@ -29,7 +29,8 @@ import torch
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
...
...
@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def
__init__
(
self
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
super
().
__init__
(
config
=
config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment