Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36a08630
Unverified
Commit
36a08630
authored
Feb 13, 2025
by
Qubitium-ModelCloud
Committed by
GitHub
Feb 12, 2025
Browse files
[CORE] [QUANT] Support for GPTQModel's `dynamic` quantization per module override/control (#7086)
parent
2c2b560f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
281 additions
and
56 deletions
+281
-56
tests/quantization/test_gptq_dynamic.py
tests/quantization/test_gptq_dynamic.py
+68
-0
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+12
-13
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+3
-3
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+38
-9
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+47
-12
vllm/model_executor/layers/quantization/utils/gptq_utils.py
vllm/model_executor/layers/quantization/utils/gptq_utils.py
+94
-0
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+18
-18
No files found.
tests/quantization/test_gptq_dynamic.py
0 → 100644
View file @
36a08630
# SPDX-License-Identifier: Apache-2.0
"""Tests whether gptq models with dynamic quantized can be loaded.
Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinLinearMethod
)
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_dynamic_override
)
PROMPT
=
"On the surface of Mars, we found"
# The first layer is quantized using bits=4, group_size=128
# The second layer is quantized using bits=8, group_size=32
# All other layers (layer index >= 2) are not quantized
MODEL_QUANT
=
[
(
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
,
True
),
(
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
,
False
),
]
@
pytest
.
mark
.
parametrize
(
"model_id, use_marlin_kernel"
,
MODEL_QUANT
)
def
test_gptq_with_dynamic
(
vllm_runner
,
model_id
:
str
,
use_marlin_kernel
:
bool
):
vllm_model
=
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
linear_method_cls
=
GPTQMarlinLinearMethod
if
use_marlin_kernel
else
(
GPTQLinearMethod
)
for
name
,
submodule
in
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
.
named_modules
()):
if
name
==
"lm_head"
:
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
elif
name
==
'model.layers.0.self_attn.qkv_proj'
:
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
config
.
weight_bits
==
4
assert
config
.
group_size
==
128
assert
config
.
desc_act
elif
name
==
'model.layers.1.self_attn.qkv_proj'
:
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"bits"
)
==
8
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"group_size"
)
==
32
assert
not
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"desc_act"
)
elif
(
name
==
'model.layers.2.self_attn.qkv_proj'
or
name
==
'model.layers.2.mlp.gate_up_proj'
):
# All other layers (layer index >= 2) are not quantized
assert
isinstance
(
submodule
.
quant_method
,
UnquantizedLinearMethod
)
del
vllm_model
tests/quantization/test_lm_head.py
View file @
36a08630
...
...
@@ -3,7 +3,6 @@
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from
typing
import
Tuple
import
pytest
import
torch
...
...
@@ -17,31 +16,31 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
PROMPT
=
"On the surface of Mars, we found"
MODELS_QUANT
=
[(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse"
,
True
),
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
False
),
(
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
False
)]
MODELS_QUANT
=
[
(
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head"
,
True
),
(
"ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024"
,
False
),
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
False
),
(
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
False
)
]
@
pytest
.
mark
.
parametrize
(
"model_lm_head_quant"
,
MODELS_QUANT
)
@
pytest
.
mark
.
parametrize
(
"model_
id,
lm_head_quant
ized
"
,
MODELS_QUANT
)
def
test_lm_head
(
vllm_runner
,
model_lm_head_quant
:
Tuple
[
str
,
bool
],
model_id
:
str
,
lm_head_quantized
:
bool
,
)
->
None
:
model
,
lm_head_quantized
=
model_lm_head_quant
with
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
vllm_model
:
def
check_model
(
model
):
lm_head_layer
=
model
.
lm_head
if
lm_head_quantized
:
assert
isinstance
(
lm_head_layer
.
linear
_method
,
assert
isinstance
(
lm_head_layer
.
quant
_method
,
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
else
:
assert
isinstance
(
lm_head_layer
.
linear
_method
,
assert
isinstance
(
lm_head_layer
.
quant
_method
,
UnquantizedEmbeddingMethod
)
vllm_model
.
apply_model
(
check_model
)
...
...
vllm/lora/layers.py
View file @
36a08630
...
...
@@ -1039,7 +1039,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
linear
_method
.
apply
(
lm_head
,
hidden_states
)
logits
=
lm_head
.
quant
_method
.
apply
(
lm_head
,
hidden_states
)
if
embedding_bias
is
not
None
:
logits
+=
embedding_bias
...
...
vllm/model_executor/layers/logits_processor.py
View file @
36a08630
...
...
@@ -108,9 +108,9 @@ class LogitsProcessor(nn.Module):
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
linear
_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
logits
=
lm_head
.
quant
_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
# Gather logits for TP
logits
=
self
.
_gather_logits
(
logits
)
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
36a08630
...
...
@@ -3,16 +3,17 @@
import
enum
from
enum
import
Enum
from
fractions
import
Fraction
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
...
...
@@ -32,7 +33,33 @@ class GPTQConfig(QuantizationConfig):
group_size
:
int
,
desc_act
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
Dict
[
str
,
Dict
[
str
,
Union
[
int
,
bool
]]],
)
->
None
:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self
.
dynamic
=
dynamic
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
...
...
@@ -47,7 +74,8 @@ class GPTQConfig(QuantizationConfig):
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
),"
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
"
)
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
), "
f
"dynamic=
{
self
.
dynamic
}
"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -68,19 +96,20 @@ class GPTQConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQConfig"
:
dynamic
=
cls
.
get_from_keys_or
(
config
,
[
"dynamic"
],
default
=
{})
dynamic
=
{}
if
dynamic
is
None
else
dynamic
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
,
dynamic
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"GPTQLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQLinearMethod
(
self
)
return
None
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQLinearMethod
)
class
ExllamaState
(
Enum
):
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
36a08630
...
...
@@ -9,17 +9,21 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
UnquantizedLinearMethod
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
UnquantizedEmbeddingMethod
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
...
...
@@ -47,12 +51,41 @@ class GPTQMarlinConfig(QuantizationConfig):
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
Dict
[
str
,
Dict
[
str
,
Union
[
int
,
bool
]]],
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act
=
False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self
.
dynamic
=
dynamic
self
.
weight_bits
=
weight_bits
self
.
is_sym
=
is_sym
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
...
...
@@ -68,7 +101,8 @@ class GPTQMarlinConfig(QuantizationConfig):
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
), "
f
"dynamic=
{
self
.
dynamic
}
"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -88,6 +122,9 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQMarlinConfig"
:
dynamic
=
cls
.
get_from_keys_or
(
config
,
[
"dynamic"
],
default
=
{})
dynamic
=
{}
if
dynamic
is
None
else
dynamic
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
...
...
@@ -95,7 +132,7 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
,
lm_head_quantized
)
lm_head_quantized
,
dynamic
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
...
@@ -120,17 +157,15 @@ class GPTQMarlinConfig(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
Union
[
"GPTQMarlinLinearMethod"
,
"GPTQMarlinMoEMethod"
]]:
if
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
):
return
GPTQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
)
->
Optional
[
Union
[
"GPTQMarlinLinearMethod"
,
"GPTQMarlinMoEMethod"
,
UnquantizedLinearMethod
,
UnquantizedEmbeddingMethod
]]:
if
isinstance
(
layer
,
FusedMoE
):
return
GPTQMarlinMoEMethod
(
self
)
return
None
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQMarlinLinearMethod
)
@
classmethod
def
is_gptq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
)
group_size
=
quant_config
.
get
(
"group_size"
)
...
...
@@ -143,7 +178,7 @@ class GPTQMarlinConfig(QuantizationConfig):
if
quant_method
!=
"gptq"
:
return
False
#
If we cannot find the info needed in the config, cannot convert.
#
Marlin conversion is only valid if required properties are found
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
or
desc_act
is
None
):
return
False
...
...
vllm/model_executor/layers/quantization/utils/gptq_utils.py
0 → 100644
View file @
36a08630
# SPDX-License-Identifier: Apache-2.0
import
re
from
copy
import
deepcopy
from
typing
import
Dict
,
Optional
,
Union
import
torch
from
vllm.config
import
QuantizationConfig
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
UnquantizedEmbeddingMethod
)
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def
override_config
(
config
:
QuantizationConfig
,
prefix
:
str
):
weight_bits
=
get_dynamic_override
(
config
,
prefix
,
"bits"
,
config
.
weight_bits
)
if
isinstance
(
weight_bits
,
int
):
config
.
weight_bits
=
weight_bits
group_size
=
get_dynamic_override
(
config
,
prefix
,
"group_size"
,
config
.
group_size
)
if
isinstance
(
group_size
,
int
):
config
.
group_size
=
group_size
desc_act
=
get_dynamic_override
(
config
,
prefix
,
"desc_act"
,
config
.
desc_act
)
if
isinstance
(
desc_act
,
bool
):
config
.
desc_act
=
desc_act
config
.
pack_factor
=
32
//
config
.
weight_bits
# packed into int32
if
config
.
get_name
()
==
"gptq_marlin"
:
is_sym
=
get_dynamic_override
(
config
,
prefix
,
"sym"
,
config
.
is_sym
)
if
isinstance
(
is_sym
,
bool
):
config
.
is_sym
=
is_sym
if
(
config
.
weight_bits
,
config
.
is_sym
)
not
in
config
.
TYPE_MAP
:
raise
ValueError
(
"Unsupported quantization config: "
f
"bits=
{
config
.
weight_bits
}
, sym=
{
config
.
is_sym
}
"
)
config
.
quant_type
=
config
.
TYPE_MAP
[(
config
.
weight_bits
,
config
.
is_sym
)]
elif
config
.
get_name
()
==
"gptq"
:
if
config
.
weight_bits
not
in
[
2
,
3
,
4
,
8
]:
raise
ValueError
(
"Currently, only 2/3/4/8-bit weight quantization is "
f
"supported for GPTQ, but got
{
config
.
weight_bits
}
bits."
)
def
get_dynamic_override
(
config
:
QuantizationConfig
,
layer_name
:
str
,
key
:
Optional
[
str
]
=
None
,
default_value
:
Union
[
int
,
bool
,
None
]
=
None
)
->
Union
[
Dict
,
int
,
bool
,
None
]:
for
pattern
,
pattern_dict
in
config
.
dynamic
.
items
():
# Negative match: matched modules are excluded from quantized init
if
pattern
.
startswith
(
"-:"
):
if
re
.
match
(
pattern
.
removeprefix
(
"-:"
),
layer_name
):
return
False
# Positive match: matched modules have quant properties overrides
# base quant config
elif
re
.
match
(
pattern
.
removeprefix
(
"+:"
),
layer_name
):
if
key
is
None
:
return
pattern_dict
else
:
return
pattern_dict
.
get
(
key
,
default_value
)
return
default_value
def
get_linear_quant_method
(
config
:
QuantizationConfig
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
linear_method_cls
:
type
,
):
cloned_config
=
deepcopy
(
config
)
parallel_lm_head_quantized
=
isinstance
(
layer
,
ParallelLMHead
)
and
cloned_config
.
lm_head_quantized
if
isinstance
(
layer
,
LinearBase
)
or
parallel_lm_head_quantized
:
# False = skip module, None = no override, else = Positive match
if
get_dynamic_override
(
# noqa: E712
cloned_config
,
# noqa: E712
layer_name
=
prefix
)
==
False
:
# noqa: E712
if
parallel_lm_head_quantized
:
return
UnquantizedEmbeddingMethod
()
return
UnquantizedLinearMethod
()
if
prefix
:
# Dynamic per module/layer rules may override base config
override_config
(
cloned_config
,
prefix
=
prefix
)
return
linear_method_cls
(
cloned_config
)
return
None
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
36a08630
...
...
@@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
tp_size
)
self
.
embedding_dim
=
embedding_dim
linear
_method
=
None
quant
_method
=
None
if
quant_config
is
not
None
:
linear
_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
linear
_method
is
None
:
linear
_method
=
UnquantizedEmbeddingMethod
()
quant
_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
quant
_method
is
None
:
quant
_method
=
UnquantizedEmbeddingMethod
()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer
=
type
(
self
.
__class__
)
is
VocabParallelEmbedding
linear
_method_implements_embedding
=
method_has_implemented_embedding
(
type
(
linear
_method
))
if
is_embedding_layer
and
not
linear
_method_implements_embedding
:
quant
_method_implements_embedding
=
method_has_implemented_embedding
(
type
(
quant
_method
))
if
is_embedding_layer
and
not
quant
_method_implements_embedding
:
raise
NotImplementedError
(
f
"The class
{
type
(
linear
_method
).
__name__
}
must implement "
f
"The class
{
type
(
quant
_method
).
__name__
}
must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self
.
linear
_method
:
QuantizeMethodBase
=
linear
_method
self
.
quant
_method
:
QuantizeMethodBase
=
quant
_method
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
...
...
@@ -260,13 +260,13 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
shard_indices
.
added_vocab_end_index
-
self
.
shard_indices
.
added_vocab_start_index
)
self
.
linear
_method
.
create_weights
(
self
,
self
.
embedding_dim
,
[
self
.
num_embeddings_per_partition
],
self
.
embedding_dim
,
self
.
num_embeddings_padded
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
self
.
quant
_method
.
create_weights
(
self
,
self
.
embedding_dim
,
[
self
.
num_embeddings_per_partition
],
self
.
embedding_dim
,
self
.
num_embeddings_padded
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
@
classmethod
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
...
...
@@ -412,8 +412,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else
:
masked_input
=
input_
# Get the embeddings.
output_parallel
=
self
.
linear
_method
.
embedding
(
self
,
masked_input
.
long
())
output_parallel
=
self
.
quant
_method
.
embedding
(
self
,
masked_input
.
long
())
# Mask the output embedding.
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment