Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
458 additions
and
177 deletions
+458
-177
vllm/lora/punica.py
vllm/lora/punica.py
+21
-17
vllm/lora/request.py
vllm/lora/request.py
+1
-0
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+5
-0
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+2
-2
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+9
-6
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+30
-14
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-2
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+16
-5
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+8
-2
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+30
-5
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+8
-7
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+7
-6
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+4
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+11
-11
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+221
-5
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+26
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+36
-78
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+14
-5
No files found.
vllm/lora/punica.py
View file @
539aa992
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def
compute_meta
(
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
"""
Get the information required for the sgmv kernel. With the features:
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
1. If consecutive requests in the batch use the same LoRA, this function
...
@@ -43,7 +43,7 @@ def compute_meta(
...
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# -1 means no lora should be applied. Use `no_lora` to determine whether
...
@@ -52,7 +52,7 @@ def compute_meta(
...
@@ -52,7 +52,7 @@ def compute_meta(
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
# TODO see if this can be vectorized
...
@@ -178,7 +178,7 @@ def convert_mapping(
...
@@ -178,7 +178,7 @@ def convert_mapping(
class
PunicaWrapper
:
class
PunicaWrapper
:
"""
"""
PunicaWrapper is designed to manage and provide metadata for the punica
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function
is to maintain the state information for
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
Multi-LoRA, and to provide the interface for the punica kernel.
"""
"""
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
self
.
max_length
:
int
=
0
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
is_prefill
=
False
self
.
no_lora
=
False
self
.
no_lora
=
False
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor
)
long_lora_offsets_tensor
)
else
:
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
b_seq_start_tensor
)
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor
)
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
max_length
=
max_length
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
self
.
no_lora
=
no_lora
@
property
@
property
def
prefill_metadata
(
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
]:
"""
"""
This property provides a convenient way to access the necessary
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
1. seq_start_locs: Tensor of sequence start positions
.
2. seq_lengths: Tensor of sequence lengths
2. seq_lengths: Tensor of sequence lengths
.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
self
.
batch_size
,
self
.
max_length
,
self
.
token_nums
)
@
property
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property is used to access the lora indices specifically for
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA
.
"""
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
@
property
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices
.
"""
"""
indices_padded_len
=
self
.
indices_len
[
2
]
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to the indices used for lora embeddings,
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA
.
"""
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to the indices used for long context
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora
.
"""
"""
long_lora_len
=
self
.
indices_len
[
4
]
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
"""
y_org
=
y
y_org
=
y
...
...
vllm/lora/request.py
View file @
539aa992
...
@@ -28,6 +28,7 @@ class LoRARequest(
...
@@ -28,6 +28,7 @@ class LoRARequest(
lora_path
:
str
=
""
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
long_lora_max_len
:
Optional
[
int
]
=
None
base_model_name
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__struct_fields__
:
if
'lora_local_path'
in
self
.
__struct_fields__
:
...
...
vllm/model_executor/custom_op.py
View file @
539aa992
import
torch.nn
as
nn
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
...
@@ -53,6 +54,10 @@ class CustomOp(nn.Module):
...
@@ -53,6 +54,10 @@ class CustomOp(nn.Module):
def
dispatch_forward
(
self
):
def
dispatch_forward
(
self
):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_TEST_COMPILE_NO_CUSTOM_OPS
:
return
self
.
forward_native
if
is_hip
():
if
is_hip
():
return
self
.
forward_hip
return
self
.
forward_hip
elif
is_cpu
():
elif
is_cpu
():
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
539aa992
...
@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
...
@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
instruction
=
self
.
_guide
.
get_next_instruction
(
instruction
=
self
.
_guide
.
get_next_instruction
(
state
=
self
.
_fsm_state
[
seq_id
])
state
=
self
.
_fsm_state
[
seq_id
])
if
type
(
instruction
)
==
Generate
:
if
type
(
instruction
)
==
Generate
:
# noqa: E721
allowed_tokens
=
instruction
.
tokens
allowed_tokens
=
instruction
.
tokens
elif
type
(
instruction
)
==
Write
:
elif
type
(
instruction
)
==
Write
:
# noqa: E721
# TODO: support fast forward tokens
# TODO: support fast forward tokens
allowed_tokens
=
[
instruction
.
tokens
[
0
]]
allowed_tokens
=
[
instruction
.
tokens
[
0
]]
else
:
else
:
...
...
vllm/model_executor/layers/activation.py
View file @
539aa992
...
@@ -124,9 +124,7 @@ class NewGELU(CustomOp):
...
@@ -124,9 +124,7 @@ class NewGELU(CustomOp):
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
return
ops
.
gelu_new
(
x
)
ops
.
gelu_new
(
out
,
x
)
return
out
class
FastGELU
(
CustomOp
):
class
FastGELU
(
CustomOp
):
...
@@ -146,9 +144,7 @@ class FastGELU(CustomOp):
...
@@ -146,9 +144,7 @@ class FastGELU(CustomOp):
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
return
ops
.
gelu_fast
(
x
)
ops
.
gelu_fast
(
out
,
x
)
return
out
class
QuickGELU
(
CustomOp
):
class
QuickGELU
(
CustomOp
):
...
@@ -165,6 +161,13 @@ class QuickGELU(CustomOp):
...
@@ -165,6 +161,13 @@ class QuickGELU(CustomOp):
ops
.
gelu_quick
(
out
,
x
)
ops
.
gelu_quick
(
out
,
x
)
return
out
return
out
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_quick
(
out
,
x
)
return
out
# TODO implement forward_xpu for QuickGELU
# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
539aa992
...
@@ -7,18 +7,21 @@ import torch
...
@@ -7,18 +7,21 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.scalar_type
import
scalar_types
def
single_marlin_moe
(
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
torch
.
Tensor
:
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_bits
:
int
=
8
,
)
->
torch
.
Tensor
:
"""
"""
This function computes the multiplication of hidden_states with expert
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
weights used in Marlin MoE, using weights w and top-k gating mechanism.
...
@@ -36,6 +39,7 @@ def single_marlin_moe(
...
@@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
...
@@ -48,10 +52,11 @@ def single_marlin_moe(
...
@@ -48,10 +52,11 @@ def single_marlin_moe(
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
hidden_states
.
dtype
==
torch
.
float16
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
M
,
K
=
hidden_states
.
shape
E
=
w
.
shape
[
0
]
E
=
w
.
shape
[
0
]
N
=
w
.
shape
[
2
]
//
2
N
=
w
.
shape
[
2
]
//
(
num_bits
//
2
)
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
renormalize
)
...
@@ -76,10 +81,13 @@ def single_marlin_moe(
...
@@ -76,10 +81,13 @@ def single_marlin_moe(
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
False
)
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
g_idx
,
perm
,
workspace
,
M
,
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
g_idx
,
perm
,
workspace
,
scalar_type
,
M
,
N
,
K
,
True
,
E
,
topk
,
False
)
block_size_m
,
True
,
False
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
...
@@ -98,6 +106,7 @@ def fused_marlin_moe(
...
@@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
...
@@ -122,6 +131,7 @@ def fused_marlin_moe(
...
@@ -122,6 +131,7 @@ def fused_marlin_moe(
w1.
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
...
@@ -131,13 +141,14 @@ def fused_marlin_moe(
...
@@ -131,13 +141,14 @@ def fused_marlin_moe(
0
],
"Number of tokens mismatch"
0
],
"Number of tokens mismatch"
assert
hidden_states
.
shape
[
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
(
1
]
==
w2
.
shape
[
2
]
//
2
,
"Hidden size mismatch w2"
num_bits
//
2
)
,
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
hidden_states
.
dtype
==
torch
.
float16
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
E
=
w1
.
shape
[
0
]
...
@@ -165,6 +176,9 @@ def fused_marlin_moe(
...
@@ -165,6 +176,9 @@ def fused_marlin_moe(
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
False
)
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
intermediate_cache2
=
torch
.
empty
(
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
...
@@ -181,6 +195,7 @@ def fused_marlin_moe(
...
@@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1
,
g_idx1
,
perm1
,
perm1
,
workspace
,
workspace
,
scalar_type
,
M
,
M
,
2
*
N
,
2
*
N
,
K
,
K
,
...
@@ -204,6 +219,7 @@ def fused_marlin_moe(
...
@@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2
,
g_idx2
,
perm2
,
perm2
,
workspace
,
workspace
,
scalar_type
,
M
,
M
,
K
,
K
,
N
,
N
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
539aa992
...
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if
renormalize
:
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
.
to
(
torch
.
int32
)
return
topk_weights
.
to
(
torch
.
float32
)
,
topk_ids
.
to
(
torch
.
int32
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
539aa992
...
@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
...
@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
==
"CompressedTensors
WNA16
MoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
# Case input scale: input_scale loading is only supported for fp8
if
"input_scale"
in
weight_name
:
if
"input_scale"
in
weight_name
:
# this is needed for compressed-tensors only
loaded_weight
=
loaded_weight
.
to
(
param
.
data
.
device
)
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
raise
ValueError
(
...
...
vllm/model_executor/layers/layernorm.py
View file @
539aa992
...
@@ -99,14 +99,11 @@ class RMSNorm(CustomOp):
...
@@ -99,14 +99,11 @@ class RMSNorm(CustomOp):
self
.
variance_epsilon
,
self
.
variance_epsilon
,
)
)
return
x
,
residual
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
return
ops
.
rms_norm
(
ops
.
rms_norm
(
out
,
x
,
x
,
self
.
weight
.
data
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
self
.
variance_epsilon
,
)
)
return
out
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
...
...
vllm/model_executor/layers/linear.py
View file @
539aa992
...
@@ -549,8 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -549,8 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
# bitsandbytes loads the weights of the specific portion
shard_size
)
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
...
@@ -918,8 +921,13 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -918,8 +921,13 @@ class QKVParallelLinear(ColumnParallelLinear):
else
:
else
:
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
start_idx
=
shard_id
*
shard_size
start_idx
=
shard_id
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for for AQLM codebooks.
# Special case for for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
...
@@ -1019,6 +1027,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1019,6 +1027,7 @@ class RowParallelLinear(LinearBase):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
# Special case for GGUF
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
@@ -1034,7 +1043,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1034,7 +1043,9 @@ class RowParallelLinear(LinearBase):
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
param_data
=
param
.
data
if
input_dim
is
not
None
:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
input_dim
is
not
None
and
not
use_bitsandbytes_4bit
:
shard_size
=
param_data
.
shape
[
input_dim
]
shard_size
=
param_data
.
shape
[
input_dim
]
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
539aa992
# Copyright (c) 2024, Tri Dao.
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from
typing
import
Optional
from
typing
import
Optional
...
@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
):
activation
:
Optional
[
str
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
"""
x: (batch, dim)
x: (batch, dim)
conv_state: (batch, dim, width)
conv_state: (batch, dim, width)
weight: (dim, width)
weight: (dim, width)
bias: (dim,)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
out: (batch, dim)
"""
"""
...
@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_bool
)
activation_bool
,
conv_state_indices
)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
539aa992
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
import
torch
import
torch
import
triton
import
triton
...
@@ -27,6 +28,10 @@ else:
...
@@ -27,6 +28,10 @@ else:
{
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
{
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_STATE_BATCH_INDICES"
:
lambda
args
:
args
[
"state_batch_indices_ptr"
]
is
not
None
})
@
triton
.
heuristics
(
@
triton
.
heuristics
(
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])})
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])})
@
triton
.
jit
@
triton
.
jit
...
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
...
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
D_ptr
,
D_ptr
,
z_ptr
,
z_ptr
,
out_ptr
,
out_ptr
,
state_batch_indices_ptr
,
# Matrix dimensions
# Matrix dimensions
batch
,
batch
,
nheads
,
nheads
,
...
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
...
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
HAS_DT_BIAS
:
tl
.
constexpr
,
HAS_DT_BIAS
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_STATE_BATCH_INDICES
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
):
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if
HAS_STATE_BATCH_INDICES
:
state_batch_indices_ptr
+=
pid_b
state_batch_idx
=
tl
.
load
(
state_batch_indices_ptr
)
state_ptr
+=
(
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
)
else
:
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
if
HAS_DT_BIAS
:
if
HAS_DT_BIAS
:
...
@@ -177,7 +195,8 @@ def selective_state_update(state,
...
@@ -177,7 +195,8 @@ def selective_state_update(state,
D
=
None
,
D
=
None
,
z
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
):
dt_softplus
=
False
,
state_batch_indices
=
None
):
"""
"""
Argument:
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
@@ -211,7 +230,10 @@ def selective_state_update(state,
...
@@ -211,7 +230,10 @@ def selective_state_update(state,
z
=
z
.
unsqueeze
(
1
)
z
=
z
.
unsqueeze
(
1
)
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
batch
,
nheads
,
dim
,
dstate
=
state
.
shape
_
,
nheads
,
dim
,
dstate
=
state
.
shape
batch
=
x
.
shape
[
0
]
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
dt
.
shape
==
x
.
shape
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
...
@@ -225,6 +247,8 @@ def selective_state_update(state,
...
@@ -225,6 +247,8 @@ def selective_state_update(state,
assert
z
.
shape
==
x
.
shape
assert
z
.
shape
==
x
.
shape
if
dt_bias
is
not
None
:
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
if
state_batch_indices
is
not
None
:
assert
state_batch_indices
.
shape
==
(
batch
,
)
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
'BLOCK_SIZE_M'
]),
batch
,
nheads
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
'BLOCK_SIZE_M'
]),
batch
,
nheads
)
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
...
@@ -249,6 +273,7 @@ def selective_state_update(state,
...
@@ -249,6 +273,7 @@ def selective_state_update(state,
D
,
D
,
z
,
z
,
out
,
out
,
state_batch_indices
,
batch
,
batch
,
nheads
,
nheads
,
dim
,
dim
,
...
@@ -336,8 +361,8 @@ def selective_scan_fn(u,
...
@@ -336,8 +361,8 @@ def selective_scan_fn(u,
x
[:,
:,
0
,
0
::
2
]
=
1
x
[:,
:,
0
,
0
::
2
]
=
1
if
prev_state
is
not
None
:
if
prev_state
is
not
None
:
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
out
,
x
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
out
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
position_indices
,
x
)
delta_softplus
,
position_indices
,
x
)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
if
z
is
None
:
if
z
is
None
:
return
out
if
not
return_last_state
else
(
out
,
last_state
)
return
out
if
not
return_last_state
else
(
out
,
last_state
)
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
539aa992
...
@@ -7,10 +7,11 @@ from vllm.logger import init_logger
...
@@ -7,10 +7,11 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
...
@@ -110,9 +111,9 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -110,9 +111,9 @@ class AWQMarlinConfig(QuantizationConfig):
def
is_awq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
def
is_awq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
num_bits
=
quant_config
.
get
(
"bits"
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
)
has_zp
=
quant_config
.
get
(
"zero_point"
,
None
)
has_zp
=
quant_config
.
get
(
"zero_point"
)
if
quant_method
!=
"awq"
:
if
quant_method
!=
"awq"
:
return
False
return
False
...
@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
replace_
tenso
r
(
layer
,
"qweight"
,
marlin_qweight
)
replace_
paramete
r
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from AWQ format to marlin format.
# Permute scales from AWQ format to marlin format.
marlin_scales
=
marlin_permute_scales
(
marlin_scales
=
marlin_permute_scales
(
...
@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
group_size
=
self
.
quant_config
.
group_size
)
replace_
tenso
r
(
layer
,
"scales"
,
marlin_scales
)
replace_
paramete
r
(
layer
,
"scales"
,
marlin_scales
)
# Permute zero-points from AWQ format to marlin format.
# Permute zero-points from AWQ format to marlin format.
marlin_zp
=
awq_to_marlin_zero_points
(
marlin_zp
=
awq_to_marlin_zero_points
(
...
@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k
=
layer
.
num_groups
,
size_k
=
layer
.
num_groups
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
replace_
tenso
r
(
layer
,
"qzeros"
,
marlin_zp
)
replace_
paramete
r
(
layer
,
"qzeros"
,
marlin_zp
)
# Not-used
# Not-used
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
539aa992
...
@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
...
@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
c
=
accumulator
.
to
(
c_ptr
.
type
.
element_ty
)
c
=
accumulator
.
to
(
c_ptr
.
type
.
element_ty
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
N
*
offs_cm
[:,
None
]
+
offs_cn
[
None
,
:]
c_ptrs
=
c_ptr
+
pid_z
*
N
*
M
+
N
*
offs_cm
[:,
None
]
+
offs_cn
[
None
,
:]
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
else
:
tl
.
atomic_add
(
c_ptrs
,
c
,
mask
=
c_mask
)
# qweights - [K , M // 8], int32
# qweights - [K , M // 8], int32
...
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
...
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
split_k_iters
,
split_k_iters
,
)
)
result
=
torch
.
zeros
((
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
result
=
torch
.
zeros
((
split_k_iters
,
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
# A = input, B = qweight, C = result
# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
# A = M x K, B = K x N, C = M x N
...
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
...
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
BLOCK_SIZE_K
=
block_size_k
,
BLOCK_SIZE_K
=
block_size_k
,
SPLIT_K
=
split_k_iters
)
SPLIT_K
=
split_k_iters
)
result
=
result
.
sum
(
0
)
return
result
return
result
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
539aa992
...
@@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
...
@@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
def
__init__
(
self
,
quant_config
:
BitsAndBytesConfig
):
try
:
try
:
import
bitsandbytes
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.4
2
.0"
:
if
bitsandbytes
.
__version__
<
"0.4
4
.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.4
2
.0."
)
"install bitsandbytes>=0.4
4
.0."
)
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.4
2
.0 via "
raise
ImportError
(
"Please install bitsandbytes>=0.4
4
.0 via "
"`pip install bitsandbytes>=0.4
2
.0` to use "
"`pip install bitsandbytes>=0.4
4
.0` to use "
"bitsandbytes quantizer."
)
from
err
"bitsandbytes quantizer."
)
from
err
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
539aa992
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
import
torch
import
torch
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
...
@@ -73,14 +73,14 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -73,14 +73,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
(
self
)
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
)
return
None
return
None
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
ignore
=
cast
(
List
[
str
]
,
config
.
get
(
"ignore"
)
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
quant_format
=
cast
(
str
,
config
.
get
(
"format"
)
)
# The quant_config has multiple config_groups, each containing
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# an input_activations key with details about how the activations are
...
@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_check_scheme_supported
(
self
,
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
error
:
bool
=
True
)
->
bool
:
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
_tuple
=
current_platform
.
get_device_capability
()
if
capability
is
not
None
:
if
capability
_tuple
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
_tuple
.
to_int
()
supported
=
capability
>=
min_capability
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
if
error
and
not
supported
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
])
if
not
(
is_symmetric_weight
and
is_static_weight
if
not
(
is_symmetric_weight
and
is_static_weight
# noqa: SIM103
and
is_per_tensor_or_channel_weight
):
and
is_per_tensor_or_channel_weight
):
return
False
return
False
...
@@ -333,7 +333,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -333,7 +333,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
**
extra_weight_attrs
):
"""
"""
Use the CompressedTensorsScheme associated with each layer to create
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
the necessary parameters for the layer. See LinearMethodBase for param
details
details
"""
"""
...
@@ -352,8 +352,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -352,8 +352,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
"""
Use the output of create_weights and the CompressedTensorsScheme
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
layer input. See LinearMethodBase for param details
"""
"""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
539aa992
...
@@ -5,10 +5,16 @@ from typing import Callable, List, Optional
...
@@ -5,10 +5,16 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
CompressionFormat
,
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
,
print_warning_once
class
GPTQMarlinState
(
Enum
):
class
GPTQMarlinState
(
Enum
):
...
@@ -16,11 +22,219 @@ class GPTQMarlinState(Enum):
...
@@ -16,11 +22,219 @@ class GPTQMarlinState(Enum):
READY
=
enum
.
auto
()
READY
=
enum
.
auto
()
__all__
=
[
"CompressedTensorsMoEMethod"
]
__all__
=
[
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
]
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
@
staticmethod
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
)
->
"CompressedTensorsMoEMethod"
:
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
else
:
raise
RuntimeError
(
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
class
CompressedTensorsW8A8Fp8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
not
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
):
raise
ValueError
(
"For FP8 Fused MoE layers, only per-tensor scales"
"for weights and activations are supported. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
static_input_scales
:
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
static_input_scales
:
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
w13_input_scale
)
or
not
all_close_1d
(
layer
.
w2_input_scale
)):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsWNA16MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
def
__init__
(
self
,
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
...
@@ -38,10 +252,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -38,10 +252,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
not
(
self
.
quant_config
.
quant_format
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
==
4
):
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
raise
ValueError
(
"For Fused MoE layers, only "
,
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for 4 bits"
)
"is supported for the following bits: "
,
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
...
@@ -292,4 +507,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -292,4 +507,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
topk_ids
,
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
num_bits
=
self
.
num_bits
,
)
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
539aa992
...
@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
apply_fp8_linear
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.utils
import
is_hip
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
...
@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
if
is_hip
():
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# If channelwise, scales are already lined up, so just transpose.
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
else
:
weight_scale
=
layer
.
weight_scale
.
data
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
requires_grad
=
False
)
else
:
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
539aa992
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
,
Set
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.kernels
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_repeat_scales_on_all_ranks
)
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
...
@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
...
@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter
)
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsWNA16"
]
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_TYPES_MAP
=
{
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
4
:
scalar_types
.
uint4b8
,
...
@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
...
@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class
CompressedTensorsWNA16
(
CompressedTensorsScheme
):
class
CompressedTensorsWNA16
(
CompressedTensorsScheme
):
_kernel_backends_being_used
:
Set
[
str
]
=
set
()
def
__init__
(
self
,
def
__init__
(
self
,
strategy
:
str
,
strategy
:
str
,
...
@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
# Verify supported on platform.
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
# ampere and up
# ampere and up
return
80
return
80
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
in
put_size
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
out
put_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
**
kwargs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
mp_linear_kernel_config
=
MPLinearLayerConfig
(
full_weight_shape
=
(
input_size
,
output_size
),
partition_weight_shape
=
\
(
input_size_per_partition
,
output_size_per_partition
),
weight_type
=
self
.
quant_type
,
act_type
=
params_dtype
,
group_size
=
self
.
group_size
,
zero_points
=
False
,
has_g_idx
=
self
.
has_g_idx
)
kernel_type
=
choose_mp_linear_kernel
(
mp_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for CompressedTensorsWNA16"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# If group_size is -1, we are in channelwise case.
# If group_size is -1, we are in channelwise case.
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
row_parallel
=
(
input_size
!=
input_size_per_partition
)
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
scales_and_zp_size
=
input_size
//
group_size
scales_and_zp_size
=
input_size
//
group_size
if
partition_scales
:
if
partition_scales
:
...
@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
input_size_per_partition
=
input_size_per_partition
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
layer
.
output_size_per_partition
=
output_size_per_partition
w_q_param_name
=
"weight_packed"
,
layer
.
input_size
=
input_size
w_s_param_name
=
"weight_scale"
,
layer
.
group_size
=
group_size
w_zp_param_name
=
None
,
w_gidx_param_name
=
"weight_g_idx"
)
# Checkpoints are serialized in compressed-tensors format, which is
# Checkpoints are serialized in compressed-tensors format, which is
# different from
marlin forma
t. Handle repacking here.
# different from
the format the kernel may wan
t. Handle repacking here.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
weight_packed
.
device
self
.
kernel
.
process_weights_after_loading
(
layer
)
# Allocate marlin workspace.
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Handle sorting for activation reordering if needed.
if
self
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
weight_g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"weight_g_idx"
,
g_idx
)
else
:
layer
.
weight_g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
# Update for kernel
layer
.
weight_packed
=
torch
.
nn
.
Parameter
(
layer
.
weight_packed
.
t
().
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
(),
requires_grad
=
False
)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_type
.
size_bits
)
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
,
size_k
=
(
layer
.
input_size
if
self
.
has_g_idx
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
return
apply_gptq_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
weight_g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
True
,
bias
=
bias
)
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
539aa992
...
@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...
@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
)
apply_fp8_linear
,
normalize_e4m3fn_to_e4m3fnuz
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -32,9 +33,7 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -32,9 +33,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
self
.
use_marlin
=
not
current_platform
.
has_device_capability
(
89
)
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -127,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -127,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
weight
=
layer
.
weight
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
self
.
quant_config
.
use_marlin
:
if
self
.
quant_config
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
# Activations not quantized for marlin.
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment