Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
683e3cb9
Unverified
Commit
683e3cb9
authored
Jul 20, 2024
by
Robert Shaw
Committed by
GitHub
Jul 20, 2024
Browse files
[ Misc ] `fbgemm` checkpoints (#6559)
parent
9042d683
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
220 additions
and
39 deletions
+220
-39
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
...a-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
+2
-2
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml
...s/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml
+11
-0
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-0
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+16
-10
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+2
-2
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-2
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+3
-2
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+5
-1
vllm/model_executor/layers/quantization/deepspeedfp.py
vllm/model_executor/layers/quantization/deepspeedfp.py
+2
-3
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+158
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-2
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+2
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+2
-3
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+2
-3
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
View file @
683e3cb9
...
@@ -4,8 +4,8 @@ tasks:
...
@@ -4,8 +4,8 @@ tasks:
-
name
:
"
gsm8k"
-
name
:
"
gsm8k"
metrics
:
metrics
:
-
name
:
"
exact_match,strict-match"
-
name
:
"
exact_match,strict-match"
value
:
0.7
69
value
:
0.7
52
-
name
:
"
exact_match,flexible-extract"
-
name
:
"
exact_match,flexible-extract"
value
:
0.7
69
value
:
0.7
54
limit
:
1000
limit
:
1000
num_fewshot
:
5
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml
0 → 100644
View file @
683e3cb9
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
model_name
:
"
nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.753
-
name
:
"
exact_match,flexible-extract"
value
:
0.753
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
View file @
683e3cb9
...
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
...
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
done
lm_eval
--model
vllm
\
lm_eval
--model
vllm
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,
add_bos_token
=
true
,
distributed_executor_backend
=
"ray"
,trust_remote_code
=
true
,max_model_len
=
4096
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,distributed_executor_backend
=
"ray"
,trust_remote_code
=
true
,max_model_len
=
4096
\
--tasks
gsm8k
--num_fewshot
$FEWSHOT
--limit
$LIMIT
\
--tasks
gsm8k
--num_fewshot
$FEWSHOT
--limit
$LIMIT
\
--batch_size
$BATCH_SIZE
--batch_size
$BATCH_SIZE
vllm/_custom_ops.py
View file @
683e3cb9
...
@@ -315,6 +315,8 @@ def scaled_fp8_quant(
...
@@ -315,6 +315,8 @@ def scaled_fp8_quant(
Args:
Args:
input: The input tensor to be quantized to FP8
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
batch_dim_padding: If specified, pad the first dimension
batch_dim_padding: If specified, pad the first dimension
of the output to at least this value.
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
use_per_token_if_dynamic: Whether to do per_tensor or per_token
...
...
vllm/attention/layer.py
View file @
683e3cb9
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
cache_config
is
not
None
:
if
cache_config
is
not
None
:
...
@@ -56,7 +57,7 @@ class Attention(nn.Module):
...
@@ -56,7 +57,7 @@ class Attention(nn.Module):
self
.
_k_scale
=
1.0
self
.
_k_scale
=
1.0
self
.
_v_scale
=
1.0
self
.
_v_scale
=
1.0
quant_method
=
quant_config
.
get_quant_method
(
quant_method
=
quant_config
.
get_quant_method
(
self
)
if
quant_config
else
None
self
,
prefix
=
prefix
)
if
quant_config
else
None
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
assert
isinstance
(
quant_method
,
Fp8KVCacheMethod
)
assert
isinstance
(
quant_method
,
Fp8KVCacheMethod
)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# TODO (mgoin): kv cache dtype should be specified in the FP8
...
...
vllm/config.py
View file @
683e3cb9
...
@@ -251,7 +251,7 @@ class ModelConfig:
...
@@ -251,7 +251,7 @@ class ModelConfig:
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
if
(
self
.
quantization
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"compressed_tensors"
)):
"fbgemm_fp8"
,
"compressed_tensors"
)):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
683e3cb9
...
@@ -182,7 +182,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -182,7 +182,7 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
UnquantizedFusedMoEMethod
())
UnquantizedFusedMoEMethod
())
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
.
quant_method
.
create_weights
(
...
...
vllm/model_executor/layers/linear.py
View file @
683e3cb9
...
@@ -141,6 +141,7 @@ class LinearBase(torch.nn.Module):
...
@@ -141,6 +141,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -155,7 +156,8 @@ class LinearBase(torch.nn.Module):
...
@@ -155,7 +156,8 @@ class LinearBase(torch.nn.Module):
self
.
quant_method
:
Optional
[
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -182,9 +184,13 @@ class ReplicatedLinear(LinearBase):
...
@@ -182,9 +184,13 @@ class ReplicatedLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
quant_config
)
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
=
prefix
)
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -258,9 +264,9 @@ class ColumnParallelLinear(LinearBase):
...
@@ -258,9 +264,9 @@ class ColumnParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
...
@@ -370,7 +376,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -370,7 +376,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
prefix
:
str
=
""
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -514,7 +520,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -514,7 +520,7 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
prefix
:
str
=
""
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -707,9 +713,9 @@ class RowParallelLinear(LinearBase):
...
@@ -707,9 +713,9 @@ class RowParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
683e3cb9
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig
)
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
...
@@ -24,6 +25,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -24,6 +25,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
# The order of gptq methods is important for config.py iteration over
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
683e3cb9
...
@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
...
@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
out_group_size
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AQLMLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"AQLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
AQLMLinearMethod
(
self
)
return
AQLMLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
683e3cb9
...
@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
...
@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
return
cls
(
weight_bits
,
group_size
,
zero_point
)
return
cls
(
weight_bits
,
group_size
,
zero_point
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AWQLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"AWQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
AWQLinearMethod
(
self
)
return
AWQLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
683e3cb9
...
@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
...
@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
return
default
return
default
@
abstractmethod
@
abstractmethod
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
"""Get the quantize method to use for the quantized layer.
"""Get the quantize method to use for the quantized layer.
Args:
Args:
layer: The layer for the quant method.
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
Returns:
The quantize method. None if the given layer doesn't support quant
The quantize method. None if the given layer doesn't support quant
method.
method.
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
683e3cb9
...
@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
return
cls
(
adapter_name
,
target_modules
)
return
cls
(
adapter_name
,
target_modules
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
BitsAndBytesLinearMethod
(
self
)
return
BitsAndBytesLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
683e3cb9
...
@@ -44,8 +44,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -44,8 +44,12 @@ class CompressedTensorsConfig(QuantizationConfig):
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
return
"compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping through here
# rather than through create_weights to match other methods
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"CompressedTensorsLinearMethod"
]:
)
->
Optional
[
"CompressedTensorsLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/deepspeedfp.py
View file @
683e3cb9
...
@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
...
@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
"quantize_config.json"
,
"quantize_config.json"
,
]
]
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"DeepSpeedFPLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"DeepSpeedFPLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
DeepSpeedFPLinearMethod
(
self
)
return
DeepSpeedFPLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
0 → 100644
View file @
683e3cb9
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
# Module-level logger named after this module.
logger = init_logger(__name__)

# Maps the name of a fused layer to the names of the individual shards that
# are stacked into it (fused_name: List[shard_name]). It is consulted when
# deciding whether a fused layer appears in the checkpoint's ignore list,
# because that list is matched against shard names built from this table.
#
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}
class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8.

    Holds the two values read from the checkpoint's quantization config:
    the list of module names that must stay unquantized and the upper bound
    applied to the activation (input) scale.
    """

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        """Store the settings.

        Args:
            ignore_list: Full layer names (e.g.
                ``model.layers.0.self_attn.q_proj``) that are not quantized.
            input_scale_ub: Upper bound for the input scale; it is later
                wrapped in a float32 tensor by the linear method.
        """
        self.ignore_list = ignore_list
        self.input_scale_ub = input_scale_ub

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 89 presumably encodes CUDA compute capability 8.9;
        confirm against the caller that consumes this value.
        """
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to search for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        """Build a config from the checkpoint's quantization dict.

        Reads ``modules_to_not_convert`` as the ignore list and
        ``activation_scale_ub`` as the input-scale upper bound, both through
        ``get_from_keys`` (inherited; presumably raises when the key is
        missing, to be verified against the base class).
        """
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Tell whether the layer named ``prefix`` is in the ignore list.

        For a fused layer (a final name component found in
        ``_FUSED_LAYER_NAME_MAPPING``) the ignore list is checked for each
        shard name instead, and every shard must give the same answer.

        Args:
            prefix: Full dotted layer name, e.g.
                ``model.layers.0.self_attn.qkv_proj``.

        Returns:
            True when the layer (or all of its shards) is ignored.

        Raises:
            ValueError: If only some shards of a fused layer are ignored.
        """
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        shard_proj_names = _FUSED_LAYER_NAME_MAPPING.get(proj_name)
        if shard_proj_names is None:
            return prefix in self.ignore_list

        # Swap only the final component. str.replace() would also rewrite
        # any earlier occurrence of the fused name inside the dotted path
        # and produce shard names that can never match the ignore list.
        parent = prefix[:len(prefix) - len(proj_name)]
        outcomes = {(parent + shard_proj_name) in self.ignore_list
                    for shard_proj_name in shard_proj_names}
        if len(outcomes) > 1:
            raise ValueError(
                f"Detected some but not all shards of {prefix} "
                "are quantized. All shards of fused layers "
                "must have the same precision.")
        return outcomes.pop()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: The layer being constructed.
            prefix: The full name of the layer in the state dict.

        Returns:
            ``UnquantizedLinearMethod`` for an ignored linear layer,
            ``FBGEMMFp8LinearMethod`` for any other linear layer, and None
            for layers that are not ``LinearBase`` instances.
        """
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation names that need scaling (none)."""
        return []
class FBGEMMFp8LinearMethod(LinearMethodBase):
    """Linear method used for layers quantized with FBGEMM Fp8.

    Registers a float8_e4m3fn weight, a per-channel weight scale and an
    input-scale upper bound on the layer, then runs the matmul through
    ``apply_fp8_linear`` with a dynamically computed input scale.
    """

    def __init__(self, quant_config: FBGEMMFp8Config):
        """Keep a reference to the config that supplies ``input_scale_ub``."""
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Attach the quantized parameters and bookkeeping fields to ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Sizes of the logical outputs; their sum
                is the first dimension of the weight.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Recorded on the layer as ``orig_dtype``.
            **extra_weight_attrs: Extra attributes forwarded to the weight
                and to the weight-scale parameter.
        """
        # The unpartitioned sizes are part of the interface but not needed.
        del input_size, output_size

        total_out = sum(output_partition_sizes)

        # Bookkeeping fields stored on the layer itself.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out
        layer.orig_dtype = params_dtype

        # WEIGHT
        fp8_storage = torch.empty(total_out,
                                  input_size_per_partition,
                                  dtype=torch.float8_e4m3fn)
        weight = Parameter(fp8_storage, requires_grad=False)
        layer.register_parameter("weight", weight)
        weight_attrs = {"input_dim": 1, "output_dim": 0}
        weight_attrs.update(extra_weight_attrs)
        set_weight_attrs(weight, weight_attrs)

        # WEIGHT SCALE
        layer.register_parameter(
            "weight_scale",
            create_per_channel_scale_param(output_partition_sizes,
                                           **extra_weight_attrs))

        # INPUT SCALE UPPER BOUND
        upper_bound = torch.tensor(self.quant_config.input_scale_ub,
                                   dtype=torch.float32)
        layer.input_scale_ub = torch.nn.Parameter(upper_bound,
                                                  requires_grad=False)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Replace the loaded weight with its transpose."""
        transposed = layer.weight.t()
        layer.weight = Parameter(transposed, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` using the layer's parameters.

        No static input scale is passed (``input_scale=None``); the call
        requests per-token dynamic scaling bounded by
        ``layer.input_scale_ub``.
        """
        fp8_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            input_scale_ub=layer.input_scale_ub,
            bias=bias,
            cutlass_fp8_supported=True,
            use_per_token_if_dynamic=True,
        )
        return apply_fp8_linear(**fp8_kwargs)
vllm/model_executor/layers/quantization/fp8.py
View file @
683e3cb9
...
@@ -66,8 +66,8 @@ class Fp8Config(QuantizationConfig):
...
@@ -66,8 +66,8 @@ class Fp8Config(QuantizationConfig):
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
activation_scheme
=
activation_scheme
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
683e3cb9
...
@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
...
@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
default
=
False
)
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"GPTQLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQLinearMethod
(
self
)
return
GPTQLinearMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
683e3cb9
...
@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
" faster inference"
)
return
None
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQMarlinLinearMethod
(
self
)
return
GPTQMarlinLinearMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
683e3cb9
...
@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return
None
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"GPTQMarlin24LinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlin24LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlin24LinearMethod
(
self
)
return
GPTQMarlin24LinearMethod
(
self
)
return
None
return
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment