Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
458 additions
and
177 deletions
+458
-177
vllm/lora/punica.py
vllm/lora/punica.py
+21
-17
vllm/lora/request.py
vllm/lora/request.py
+1
-0
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+5
-0
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+2
-2
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+9
-6
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+30
-14
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-2
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+16
-5
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+8
-2
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+30
-5
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+8
-7
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+7
-6
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+4
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+11
-11
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+221
-5
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+26
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+36
-78
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+14
-5
No files found.
vllm/lora/punica.py
View file @
539aa992
...
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
...
...
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
...
...
@@ -52,7 +52,7 @@ def compute_meta(
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
...
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype
=
torch
.
long
,
device
=
device
)
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
no_lora
=
False
...
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
...
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
@
property
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions
.
2. seq_lengths: Tensor of sequence lengths
.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
self
.
batch_size
,
self
.
max_length
,
self
.
token_nums
)
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
...
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA
.
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
...
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices
.
"""
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
...
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA
.
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
...
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora
.
"""
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
...
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org
=
y
...
...
vllm/lora/request.py
View file @
539aa992
...
...
@@ -28,6 +28,7 @@ class LoRARequest(
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
base_model_name
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__struct_fields__
:
...
...
vllm/model_executor/custom_op.py
View file @
539aa992
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
...
...
@@ -53,6 +54,10 @@ class CustomOp(nn.Module):
def
dispatch_forward
(
self
):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_TEST_COMPILE_NO_CUSTOM_OPS
:
return
self
.
forward_native
if
is_hip
():
return
self
.
forward_hip
elif
is_cpu
():
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
539aa992
...
...
@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
instruction
=
self
.
_guide
.
get_next_instruction
(
state
=
self
.
_fsm_state
[
seq_id
])
if
type
(
instruction
)
==
Generate
:
if
type
(
instruction
)
==
Generate
:
# noqa: E721
allowed_tokens
=
instruction
.
tokens
elif
type
(
instruction
)
==
Write
:
elif
type
(
instruction
)
==
Write
:
# noqa: E721
# TODO: support fast forward tokens
allowed_tokens
=
[
instruction
.
tokens
[
0
]]
else
:
...
...
vllm/model_executor/layers/activation.py
View file @
539aa992
...
...
@@ -124,9 +124,7 @@ class NewGELU(CustomOp):
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_new
(
out
,
x
)
return
out
return
ops
.
gelu_new
(
x
)
class
FastGELU
(
CustomOp
):
...
...
@@ -146,9 +144,7 @@ class FastGELU(CustomOp):
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_fast
(
out
,
x
)
return
out
return
ops
.
gelu_fast
(
x
)
class
QuickGELU
(
CustomOp
):
...
...
@@ -165,6 +161,13 @@ class QuickGELU(CustomOp):
ops
.
gelu_quick
(
out
,
x
)
return
out
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_quick
(
out
,
x
)
return
out
# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
539aa992
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.scalar_type
import
scalar_types
def
single_marlin_moe
(
...
...
@@ -18,7 +19,9 @@ def single_marlin_moe(
perm
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
torch
.
Tensor
:
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_bits
:
int
=
8
,
)
->
torch
.
Tensor
:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
...
...
@@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (int): The number of bits in expert weights quantization (4 or 8).
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
...
...
@@ -48,10 +52,11 @@ def single_marlin_moe(
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
E
=
w
.
shape
[
0
]
N
=
w
.
shape
[
2
]
//
2
N
=
w
.
shape
[
2
]
//
(
num_bits
//
2
)
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
...
...
@@ -76,10 +81,13 @@ def single_marlin_moe(
device
=
"cuda"
,
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
g_idx
,
perm
,
workspace
,
M
,
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
g_idx
,
perm
,
workspace
,
scalar_type
,
M
,
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
...
...
@@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -122,6 +131,7 @@ def fused_marlin_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (int): The number of bits in expert weights quantization (4 or 8).
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
...
...
@@ -131,13 +141,14 @@ def fused_marlin_moe(
0
],
"Number of tokens mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
2
,
"Hidden size mismatch w2"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
(
num_bits
//
2
)
,
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
...
...
@@ -165,6 +176,9 @@ def fused_marlin_moe(
device
=
"cuda"
,
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
...
...
@@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1
,
perm1
,
workspace
,
scalar_type
,
M
,
2
*
N
,
K
,
...
...
@@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2
,
perm2
,
workspace
,
scalar_type
,
M
,
K
,
N
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
539aa992
...
...
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
.
to
(
torch
.
int32
)
return
topk_weights
.
to
(
torch
.
float32
)
,
topk_ids
.
to
(
torch
.
int32
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
539aa992
...
...
@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
==
"CompressedTensors
WNA16
MoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
...
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if
"input_scale"
in
weight_name
:
# this is needed for compressed-tensors only
loaded_weight
=
loaded_weight
.
to
(
param
.
data
.
device
)
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
...
...
vllm/model_executor/layers/layernorm.py
View file @
539aa992
...
...
@@ -99,14 +99,11 @@ class RMSNorm(CustomOp):
self
.
variance_epsilon
,
)
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
ops
.
rms_norm
(
out
,
return
ops
.
rms_norm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
return
out
def
extra_repr
(
self
)
->
str
:
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
...
...
vllm/model_executor/layers/linear.py
View file @
539aa992
...
...
@@ -549,6 +549,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for AQLM codebooks.
...
...
@@ -918,8 +921,13 @@ class QKVParallelLinear(ColumnParallelLinear):
else
:
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
start_idx
=
shard_id
*
shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for AQLM codebooks.
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
...
...
@@ -1019,6 +1027,7 @@ class RowParallelLinear(LinearBase):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
@@ -1034,7 +1043,9 @@ class RowParallelLinear(LinearBase):
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
if
input_dim
is
not
None
:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
input_dim
is
not
None
and
not
use_bitsandbytes_4bit
:
shard_size
=
param_data
.
shape
[
input_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
539aa992
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from
typing
import
Optional
...
...
@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
):
activation
:
Optional
[
str
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
"""
...
...
@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_bool
)
activation_bool
,
conv_state_indices
)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
539aa992
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
import
torch
import
triton
...
...
@@ -27,6 +28,10 @@ else:
{
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_STATE_BATCH_INDICES"
:
lambda
args
:
args
[
"state_batch_indices_ptr"
]
is
not
None
})
@
triton
.
heuristics
(
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])})
@
triton
.
jit
...
...
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
D_ptr
,
z_ptr
,
out_ptr
,
state_batch_indices_ptr
,
# Matrix dimensions
batch
,
nheads
,
...
...
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
HAS_DT_BIAS
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_STATE_BATCH_INDICES
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr. Otherwise, the state coordinate
# is the same as the batch id.
if
HAS_STATE_BATCH_INDICES
:
state_batch_indices_ptr
+=
pid_b
state_batch_idx
=
tl
.
load
(
state_batch_indices_ptr
)
state_ptr
+=
(
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
)
else
:
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
if
HAS_DT_BIAS
:
...
...
@@ -177,7 +195,8 @@ def selective_state_update(state,
D
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
):
dt_softplus
=
False
,
state_batch_indices
=
None
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
...
@@ -211,7 +230,10 @@ def selective_state_update(state,
z
=
z
.
unsqueeze
(
1
)
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
batch
,
nheads
,
dim
,
dstate
=
state
.
shape
_
,
nheads
,
dim
,
dstate
=
state
.
shape
batch
=
x
.
shape
[
0
]
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
...
...
@@ -225,6 +247,8 @@ def selective_state_update(state,
assert
z
.
shape
==
x
.
shape
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
if
state_batch_indices
is
not
None
:
assert
state_batch_indices
.
shape
==
(
batch
,
)
out
=
torch
.
empty_like
(
x
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
'BLOCK_SIZE_M'
]),
batch
,
nheads
)
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
...
...
@@ -249,6 +273,7 @@ def selective_state_update(state,
D
,
z
,
out
,
state_batch_indices
,
batch
,
nheads
,
dim
,
...
...
@@ -336,7 +361,7 @@ def selective_scan_fn(u,
x
[:,
:,
0
,
0
::
2
]
=
1
if
prev_state
is
not
None
:
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
out
,
x
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
out
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
position_indices
,
x
)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
if
z
is
None
:
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
539aa992
...
...
@@ -7,10 +7,11 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
...
...
@@ -110,9 +111,9 @@ class AWQMarlinConfig(QuantizationConfig):
def
is_awq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
has_zp
=
quant_config
.
get
(
"zero_point"
,
None
)
num_bits
=
quant_config
.
get
(
"bits"
)
group_size
=
quant_config
.
get
(
"group_size"
)
has_zp
=
quant_config
.
get
(
"zero_point"
)
if
quant_method
!=
"awq"
:
return
False
...
...
@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
replace_
tenso
r
(
layer
,
"qweight"
,
marlin_qweight
)
replace_
paramete
r
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from AWQ format to marlin format.
marlin_scales
=
marlin_permute_scales
(
...
...
@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
replace_
tenso
r
(
layer
,
"scales"
,
marlin_scales
)
replace_
paramete
r
(
layer
,
"scales"
,
marlin_scales
)
# Permute zero-points from AWQ format to marlin format.
marlin_zp
=
awq_to_marlin_zero_points
(
...
...
@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
num_groups
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
replace_
tenso
r
(
layer
,
"qzeros"
,
marlin_zp
)
replace_
paramete
r
(
layer
,
"qzeros"
,
marlin_zp
)
# Not-used
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
539aa992
...
...
@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
c
=
accumulator
.
to
(
c_ptr
.
type
.
element_ty
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
N
*
offs_cm
[:,
None
]
+
offs_cn
[
None
,
:]
c_ptrs
=
c_ptr
+
pid_z
*
N
*
M
+
N
*
offs_cm
[:,
None
]
+
offs_cn
[
None
,
:]
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
else
:
tl
.
atomic_add
(
c_ptrs
,
c
,
mask
=
c_mask
)
# qweights - [K , M // 8], int32
...
...
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
split_k_iters
,
)
result
=
torch
.
zeros
((
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
result
=
torch
.
zeros
((
split_k_iters
,
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
...
...
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
BLOCK_SIZE_K
=
block_size_k
,
SPLIT_K
=
split_k_iters
)
result
=
result
.
sum
(
0
)
return
result
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
539aa992
...
...
@@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
try
:
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.4
2
.0"
:
if
bitsandbytes
.
__version__
<
"0.4
4
.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.4
2
.0."
)
"install bitsandbytes>=0.4
4
.0."
)
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.4
2
.0 via "
"`pip install bitsandbytes>=0.4
2
.0` to use "
raise
ImportError
(
"Please install bitsandbytes>=0.4
4
.0 via "
"`pip install bitsandbytes>=0.4
4
.0` to use "
"bitsandbytes quantizer."
)
from
err
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
539aa992
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
import
torch
from
pydantic
import
BaseModel
...
...
@@ -73,14 +73,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
(
self
)
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
)
return
None
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
ignore
=
cast
(
List
[
str
]
,
config
.
get
(
"ignore"
)
)
quant_format
=
cast
(
str
,
config
.
get
(
"format"
)
)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
...
...
@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
_tuple
=
current_platform
.
get_device_capability
()
if
capability
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
_tuple
is
not
None
:
capability
=
capability
_tuple
.
to_int
()
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
raise
RuntimeError
(
...
...
@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
if
not
(
is_symmetric_weight
and
is_static_weight
if
not
(
is_symmetric_weight
and
is_static_weight
# noqa: SIM103
and
is_per_tensor_or_channel_weight
):
return
False
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
539aa992
...
...
@@ -5,10 +5,16 @@ from typing import Callable, List, Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
CompressionFormat
,
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
,
print_warning_once
class
GPTQMarlinState
(
Enum
):
...
...
@@ -16,11 +22,219 @@ class GPTQMarlinState(Enum):
READY
=
enum
.
auto
()
__all__
=
[
"CompressedTensorsMoEMethod"
]
__all__
=
[
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Entry point for compressed-tensors fused-MoE quantization methods.

    Concrete methods subclass this type; ``get_moe_method`` picks the one
    that matches the quantization config.
    """

    @staticmethod
    def get_moe_method(
        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ) -> "CompressedTensorsMoEMethod":
        """Return the fused-MoE method matching ``quant_config``.

        The weight and input-activation quantization arguments are read from
        the ``"Linear"`` entry of ``quant_config.target_scheme_map``.

        Raises:
            RuntimeError: if the scheme is neither WNA16 (group/channel) nor
                FP8 W8A8 according to the config's own predicates.
        """
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        linear_scheme = quant_config.target_scheme_map["Linear"]
        weight_args = linear_scheme.get("weights")
        activation_args = linear_scheme.get("input_activations")

        if quant_config._is_wNa16_group_channel(weight_args, activation_args):
            return CompressedTensorsWNA16MoEMethod(quant_config)

        if quant_config._is_fp8_w8a8(weight_args, activation_args):
            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)

        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_args}, {activation_args}")
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for compressed-tensors FP8 weight/activation schemes.

    Only per-tensor weight and activation scales are accepted. Activation
    scales may be static (loaded from the checkpoint) or dynamic (no input
    scale parameters are created).
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Read the ``"Linear"`` scheme from ``quant_config`` and validate it.

        Raises:
            ValueError: if either the weight or the input-activation strategy
                is not ``QuantizationStrategy.TENSOR``.
        """
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
            # Adjacent string literals are concatenated verbatim, so the
            # separating space has to be part of the first literal.
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor scales "
                "for weights and activations are supported. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the FP8 expert weights and their scales on ``layer``.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale`` and
        ``w2_weight_scale``. With static input scales it also registers
        ``w13_input_scale`` and ``w2_input_scale``; otherwise both attributes
        are set to ``None``.

        The incoming ``params_dtype`` is ignored: weights are always
        allocated as ``torch.float8_e4m3fn``.
        """
        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        # (set only after the weights above received their attributes, so
        # "quant_method" is attached to scale parameters only).
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse loaded scales into the form the FP8 MoE kernel expects.

        Steps, in order:
          1. Static input scales are reduced to a single maximum per layer.
          2. On ROCm (``is_hip()``), weights and scales are normalized via
             ``normalize_e4m3fn_to_e4m3fnuz`` and re-wrapped as parameters.
          3. The two per-expert w13 scales are replaced by their maximum,
             and each shard is dequantized and requantized with it.

        Raises:
            ValueError: if static input scales are configured but either
                input scale on ``layer`` is ``None``.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if is_hip():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            # Input scales are None in the dynamic case; leave them unset.
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                           requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                          requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            # Two shards per expert, matching the 2 scales allocated for
            # w1 and w3 in create_weights.
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                # Requantize in place with the shared per-expert scale.
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the fused FP8 expert computation.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        result of ``fused_experts`` (called with ``inplace=True`` and
        ``use_fp8_w8a8=True``) is returned.
        """
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)
class
CompressedTensorsWNA16MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
...
...
@@ -38,10 +252,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
==
4
):
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for 4 bits"
)
"is supported for the following bits: "
,
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
...
...
@@ -292,4 +507,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
num_bits
=
self
.
num_bits
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
539aa992
...
...
@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
apply_fp8_linear
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.utils
import
is_hip
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
...
...
@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths
=
layer
.
logical_widths
,
)
if
is_hip
():
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
else
:
weight_scale
=
layer
.
weight_scale
.
data
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
539aa992
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
,
Set
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.kernels
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
marlin_repeat_scales_on_all_ranks
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
...
...
@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
...
...
@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class
CompressedTensorsWNA16
(
CompressedTensorsScheme
):
_kernel_backends_being_used
:
Set
[
str
]
=
set
()
def
__init__
(
self
,
strategy
:
str
,
...
...
@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
# Verify supported on platform.
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum capability value this scheme requires (80)."""
    # ampere and up
    min_capability = 80
    return min_capability
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
in
put_size
:
int
,
output_partition_sizes
:
List
[
int
],
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
out
put_size
:
int
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
mp_linear_kernel_config
=
MPLinearLayerConfig
(
full_weight_shape
=
(
input_size
,
output_size
),
partition_weight_shape
=
\
(
input_size_per_partition
,
output_size_per_partition
),
weight_type
=
self
.
quant_type
,
act_type
=
params_dtype
,
group_size
=
self
.
group_size
,
zero_points
=
False
,
has_g_idx
=
self
.
has_g_idx
)
kernel_type
=
choose_mp_linear_kernel
(
mp_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for CompressedTensorsWNA16"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# If group_size is -1, we are in channelwise case.
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
scales_and_zp_size
=
input_size
//
group_size
if
partition_scales
:
...
...
@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
group_size
=
group_size
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
w_q_param_name
=
"weight_packed"
,
w_s_param_name
=
"weight_scale"
,
w_zp_param_name
=
None
,
w_gidx_param_name
=
"weight_g_idx"
)
# Checkpoints are serialized in compressed-tensors format, which is
# different from
marlin forma
t. Handle repacking here.
# different from
the format the kernel may wan
t. Handle repacking here.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
weight_packed
.
device
# Allocate marlin workspace.
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Handle sorting for activation reordering if needed.
if
self
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
weight_g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"weight_g_idx"
,
g_idx
)
else
:
layer
.
weight_g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
# Update for kernel
layer
.
weight_packed
=
torch
.
nn
.
Parameter
(
layer
.
weight_packed
.
t
().
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
(),
requires_grad
=
False
)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_type
.
size_bits
)
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
,
size_k
=
(
layer
.
input_size
if
self
.
has_g_idx
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
apply_gptq_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
weight_g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
True
,
bias
=
bias
)
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
539aa992
...
...
@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
)
apply_fp8_linear
,
normalize_e4m3fn_to_e4m3fnuz
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
...
...
@@ -32,9 +33,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
self
.
use_marlin
=
not
current_platform
.
has_device_capability
(
89
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -127,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
self
.
quant_config
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment