Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
456 additions
and
532 deletions
+456
-532
vllm/lora/peft_helper.py
vllm/lora/peft_helper.py
+24
-4
vllm/lora/request.py
vllm/lora/request.py
+1
-0
vllm/lora/utils.py
vllm/lora/utils.py
+1
-1
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+3
-1
vllm/model_executor/guided_decoding/guidance_decoding.py
vllm/model_executor/guided_decoding/guidance_decoding.py
+1
-1
vllm/model_executor/guided_decoding/guidance_logits_processors.py
...el_executor/guided_decoding/guidance_logits_processors.py
+22
-4
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+1
-1
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+12
-0
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+1
-1
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+6
-4
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+3
-15
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+101
-149
vllm/model_executor/layers/fused_moe/moe_pallas.py
vllm/model_executor/layers/fused_moe/moe_pallas.py
+18
-2
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+4
-0
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+0
-1
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+123
-278
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+9
-4
vllm/model_executor/layers/mamba/mamba2_metadata.py
vllm/model_executor/layers/mamba/mamba2_metadata.py
+18
-5
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+106
-59
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
No files found.
vllm/lora/peft_helper.py
View file @
4eabe123
...
@@ -10,6 +10,7 @@ from typing import Literal, Optional, Union
...
@@ -10,6 +10,7 @@ from typing import Literal, Optional, Union
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -89,12 +90,31 @@ class PEFTHelper:
...
@@ -89,12 +90,31 @@ class PEFTHelper:
return
cls
(
**
filtered_dict
)
return
cls
(
**
filtered_dict
)
@
classmethod
@
classmethod
def
from_local_dir
(
cls
,
lora_path
:
str
,
def
from_local_dir
(
max_position_embeddings
:
Optional
[
int
])
->
"PEFTHelper"
:
cls
,
lora_path
:
str
,
max_position_embeddings
:
Optional
[
int
],
tensorizer_config_dict
:
Optional
[
dict
]
=
None
)
->
"PEFTHelper"
:
lora_config_path
=
os
.
path
.
join
(
lora_path
,
"adapter_config.json"
)
lora_config_path
=
os
.
path
.
join
(
lora_path
,
"adapter_config.json"
)
if
tensorizer_config_dict
:
tensorizer_config
=
TensorizerConfig
(
**
tensorizer_config_dict
)
tensorizer_args
=
tensorizer_config
.
_construct_tensorizer_args
()
from
tensorizer.stream_io
import
open_stream
lora_config_path
=
os
.
path
.
join
(
tensorizer_config
.
lora_dir
,
"adapter_config.json"
)
with
open_stream
(
lora_config_path
,
mode
=
"rb"
,
**
tensorizer_args
.
stream_params
)
as
f
:
config
=
json
.
load
(
f
)
logger
.
info
(
"Successfully deserialized LoRA config from %s"
,
tensorizer_config
.
lora_dir
)
else
:
with
open
(
lora_config_path
)
as
f
:
with
open
(
lora_config_path
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
config
[
"vllm_max_position_embeddings"
]
=
max_position_embeddings
config
[
"vllm_max_position_embeddings"
]
=
max_position_embeddings
return
cls
.
from_dict
(
config
)
return
cls
.
from_dict
(
config
)
...
...
vllm/lora/request.py
View file @
4eabe123
...
@@ -31,6 +31,7 @@ class LoRARequest(
...
@@ -31,6 +31,7 @@ class LoRARequest(
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
long_lora_max_len
:
Optional
[
int
]
=
None
base_model_name
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
base_model_name
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
tensorizer_config_dict
:
Optional
[
dict
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
lora_local_path
:
if
self
.
lora_local_path
:
...
...
vllm/lora/utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
import
re
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
huggingface_hub
import
huggingface_hub
import
regex
as
re
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
HFValidationError
,
RepositoryNotFoundError
)
HFValidationError
,
RepositoryNotFoundError
)
from
torch
import
nn
from
torch
import
nn
...
...
vllm/lora/worker_manager.py
View file @
4eabe123
...
@@ -100,7 +100,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -100,7 +100,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
peft_helper
=
PEFTHelper
.
from_local_dir
(
peft_helper
=
PEFTHelper
.
from_local_dir
(
lora_path
,
self
.
max_position_embeddings
)
lora_path
,
self
.
max_position_embeddings
,
lora_request
.
tensorizer_config_dict
)
# Validates the LoRA configuration against requirements before
# Validates the LoRA configuration against requirements before
# loading weights, throwing an exception if validation fails.
# loading weights, throwing an exception if validation fails.
...
@@ -125,6 +126,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -125,6 +126,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
self
.
lora_config
.
lora_extra_vocab_size
,
self
.
lora_config
.
lora_extra_vocab_size
,
embedding_modules
=
self
.
embedding_modules
,
embedding_modules
=
self
.
embedding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
tensorizer_config_dict
=
lora_request
.
tensorizer_config_dict
,
weights_mapper
=
hf_to_vllm_mapper
)
weights_mapper
=
hf_to_vllm_mapper
)
except
FileNotFoundError
as
e
:
except
FileNotFoundError
as
e
:
...
...
vllm/model_executor/guided_decoding/guidance_decoding.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
from
re
import
escape
as
regex_escape
import
llguidance
import
llguidance
from
regex
import
escape
as
regex_escape
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm.model_executor.guided_decoding.guidance_logits_processors
import
(
from
vllm.model_executor.guided_decoding.guidance_logits_processors
import
(
...
...
vllm/model_executor/guided_decoding/guidance_logits_processors.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
os
import
os
from
typing
import
Any
from
typing
import
Any
...
@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
...
@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
self
.
grammar
=
grammar
self
.
grammar
=
grammar
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
tokenizer_name
=
tokenizer
.
name_or_path
self
.
tokenizer_name
=
tokenizer
.
name_or_path
self
.
ll_tokenizer
=
None
self
.
ll_matcher
=
None
self
.
bitmask
=
None
self
.
new_sampling
=
False
self
.
new_sampling
=
False
self
.
initialized
=
False
self
.
initialized
=
False
def
clone
(
self
)
->
"GuidanceLogitsProcessor"
:
cloned
=
copy
.
copy
(
self
)
if
self
.
initialized
:
cloned
.
ll_matcher
=
llguidance
.
LLMatcher
(
self
.
ll_tokenizer
,
# type: ignore[assignment]
self
.
grammar
,
log_level
=
int
(
os
.
environ
.
get
(
"LLGUIDANCE_LOG_LEVEL"
,
"1"
)),
)
self
.
bitmask
=
llguidance
.
torch
.
allocate_token_bitmask
(
1
,
self
.
ll_tokenizer
.
vocab_size
)
# type: ignore[attr-defined]
return
cloned
def
_initialize
(
self
):
def
_initialize
(
self
):
if
self
.
initialized
:
if
self
.
initialized
:
return
return
...
@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
...
@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
# create reusable bitmask
# create reusable bitmask
self
.
bitmask
=
llguidance
.
torch
.
allocate_token_bitmask
(
self
.
bitmask
=
llguidance
.
torch
.
allocate_token_bitmask
(
1
,
self
.
ll_tokenizer
.
vocab_size
)
1
,
self
.
ll_tokenizer
.
vocab_size
)
# type: ignore[attr-defined]
self
.
initialized
=
True
self
.
initialized
=
True
...
@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
...
@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
self
.
_initialize
()
self
.
_initialize
()
if
self
.
new_sampling
and
len
(
input_ids
)
>
0
:
if
self
.
new_sampling
and
len
(
input_ids
)
>
0
:
self
.
ll_matcher
.
consume_token
(
input_ids
[
-
1
])
self
.
ll_matcher
.
consume_token
(
# type: ignore[attr-defined]
err
=
self
.
ll_matcher
.
get_error
()
input_ids
[
-
1
])
err
=
self
.
ll_matcher
.
get_error
()
# type: ignore[attr-defined]
if
err
:
if
err
:
logger
.
warning
(
"Error in LLMatcher: %s"
,
err
)
logger
.
warning
(
"Error in LLMatcher: %s"
,
err
)
llguidance
.
torch
.
fill_next_token_bitmask
(
self
.
ll_matcher
,
self
.
bitmask
,
llguidance
.
torch
.
fill_next_token_bitmask
(
self
.
ll_matcher
,
self
.
bitmask
,
0
)
0
)
llguidance
.
torch
.
apply_token_bitmask_inplace
(
llguidance
.
torch
.
apply_token_bitmask_inplace
(
scores
,
self
.
bitmask
.
to
(
scores
.
device
))
scores
,
self
.
bitmask
.
to
(
scores
.
device
))
# type: ignore[attr-defined]
self
.
new_sampling
=
True
self
.
new_sampling
=
True
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
4eabe123
...
@@ -5,9 +5,9 @@ import concurrent.futures
...
@@ -5,9 +5,9 @@ import concurrent.futures
import
os
import
os
from
enum
import
Enum
from
enum
import
Enum
from
json
import
dumps
as
json_dumps
from
json
import
dumps
as
json_dumps
from
re
import
escape
as
regex_escape
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
regex
import
escape
as
regex_escape
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
4eabe123
...
@@ -56,6 +56,12 @@ class BaseLogitsProcessor:
...
@@ -56,6 +56,12 @@ class BaseLogitsProcessor:
self
.
_fsm_state
:
defaultdict
[
int
,
Union
[
int
,
self
.
_fsm_state
:
defaultdict
[
int
,
Union
[
int
,
CFGState
]]
=
defaultdict
(
int
)
CFGState
]]
=
defaultdict
(
int
)
def
clone
(
self
)
->
"BaseLogitsProcessor"
:
cloned
=
copy
.
copy
(
self
)
cloned
.
_guide
=
self
.
_guide
.
copy
()
cloned
.
_fsm_state
=
copy
.
deepcopy
(
self
.
_fsm_state
)
return
cloned
def
__call__
(
self
,
input_ids
:
list
[
int
],
def
__call__
(
self
,
input_ids
:
list
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Use the FSM to bias the logits before sampling the next token."""
"""Use the FSM to bias the logits before sampling the next token."""
...
@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
...
@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
reasoner
)
reasoner
)
self
.
_guide
=
self
.
_guide
.
copy
()
self
.
_guide
=
self
.
_guide
.
copy
()
def
clone
(
self
)
->
"CFGLogitsProcessor"
:
cloned
=
copy
.
copy
(
self
)
cloned
.
_fsm_state
=
copy
.
deepcopy
(
self
.
_fsm_state
)
cloned
.
_guide
=
self
.
_guide
.
copy
()
return
cloned
@
lru_cache
(
maxsize
=
32
)
@
lru_cache
(
maxsize
=
32
)
def
_adapt_tokenizer
(
tokenizer
:
PreTrainedTokenizerBase
):
def
_adapt_tokenizer
(
tokenizer
:
PreTrainedTokenizerBase
):
...
...
vllm/model_executor/guided_decoding/utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
re
import
re
gex
as
re
def
has_xgrammar_unsupported_json_features
(
schema
:
dict
)
->
bool
:
def
has_xgrammar_unsupported_json_features
(
schema
:
dict
)
->
bool
:
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
4eabe123
...
@@ -4,10 +4,10 @@
...
@@ -4,10 +4,10 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
json
import
json
import
re
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
import
regex
as
re
import
torch
import
torch
import
vllm.envs
import
vllm.envs
...
@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor:
...
@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor:
prefilled
:
bool
=
field
(
default
=
False
)
prefilled
:
bool
=
field
(
default
=
False
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer_info
is
None
:
self
.
tokenizer_info
=
self
.
config
.
tokenizer_info
(
self
.
tokenizer_info
=
self
.
config
.
tokenizer_info
(
self
.
config
.
tokenizer_data
)
self
.
config
.
tokenizer_data
)
...
@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
...
@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
def
clone
(
self
)
->
XGrammarLogitsProcessor
:
def
clone
(
self
)
->
XGrammarLogitsProcessor
:
"""Create a new instance with shared compiled grammar
"""Create a new instance with shared compiled grammar
but separate state"""
but separate state"""
new_processor
=
XGrammarLogitsProcessor
(
self
.
config
,
self
.
reasoner
)
new_processor
=
XGrammarLogitsProcessor
(
self
.
config
,
self
.
reasoner
,
None
,
self
.
tokenizer_info
)
# Share the compiled grammar context (immutable after compilation)
# Share the compiled grammar context (immutable after compilation)
new_processor
.
ctx
=
self
.
ctx
new_processor
.
ctx
=
self
.
ctx
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
""" CUTLASS based Fused MoE kernels."""
""" CUTLASS based Fused MoE kernels."""
import
os
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
...
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
MAX_TOKENS_PER_EXPERT
=
int
(
os
.
environ
.
get
(
'VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT'
,
'65536'
))
def
cutlass_moe_fp4
(
a
:
torch
.
Tensor
,
a1_gscale
:
torch
.
Tensor
,
def
cutlass_moe_fp4
(
a
:
torch
.
Tensor
,
a1_gscale
:
torch
.
Tensor
,
...
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
...
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
assert
a
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid input dtype"
assert
a
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid input dtype"
assert
(
topk_weights
.
shape
[
0
]
==
m
and
topk_ids
.
shape
[
0
]
assert
(
topk_weights
.
shape
[
0
]
==
m
and
topk_ids
.
shape
[
0
]
==
m
),
(
"topk must be provided for each row of a"
)
==
m
),
(
"topk must be provided for each row of a"
)
assert
(
m
<=
MAX_TOKENS_PER_EXPERT
),
(
f
"m must be less than MAX_TOKENS_PER_EXPERT(
{
MAX_TOKENS_PER_EXPERT
}
)"
f
" for cutlass_moe_fp4, observed m =
{
m
}
. Use"
f
" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
)
out_dtype
=
a
.
dtype
out_dtype
=
a
.
dtype
num_topk
=
topk_ids
.
shape
[
1
]
num_topk
=
topk_ids
.
shape
[
1
]
...
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
...
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
expert_offsets
,
expert_offsets
,
blockscale_offsets
,
blockscale_offsets
,
num_topk
,
num_topk
,
expert_map
=
a_map
,
expert_map
=
a_map
)
MAX_TOKENS_PER_EXPERT
=
MAX_TOKENS_PER_EXPERT
)
c1
=
ops
.
cutlass_fp4_moe_mm
(
rep_a_fp4
,
w1_fp4
,
rep_a_blockscale
,
c1
=
ops
.
cutlass_fp4_moe_mm
(
rep_a_fp4
,
w1_fp4
,
rep_a_blockscale
,
w1_blockscale
,
w1_alphas
,
problem_sizes1
,
w1_blockscale
,
w1_alphas
,
problem_sizes1
,
...
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
...
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
int_fp4
,
int_blockscale
=
ops
.
scaled_fp4_experts_quant
(
int_fp4
,
int_blockscale
=
ops
.
scaled_fp4_experts_quant
(
intermediate
,
intermediate
,
a2_gscale
,
expert_offsets
,
blockscale_offsets
,
num_topk
)
a2_gscale
,
expert_offsets
,
blockscale_offsets
,
num_topk
,
MAX_TOKENS_PER_EXPERT
=
MAX_TOKENS_PER_EXPERT
)
c2
=
ops
.
cutlass_fp4_moe_mm
(
int_fp4
,
w2_fp4
,
int_blockscale
,
w2_blockscale
,
c2
=
ops
.
cutlass_fp4_moe_mm
(
int_fp4
,
w2_fp4
,
int_blockscale
,
w2_blockscale
,
w2_alphas
,
problem_sizes2
,
expert_offsets
[:
-
1
],
w2_alphas
,
problem_sizes2
,
expert_offsets
[:
-
1
],
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
4eabe123
...
@@ -2,13 +2,11 @@
...
@@ -2,13 +2,11 @@
import
os
import
os
import
importlib
import
importlib
import
threading
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
from
weakref
import
WeakValueDictionary
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -45,6 +43,7 @@ if current_platform.is_cuda_alike():
...
@@ -45,6 +43,7 @@ if current_platform.is_cuda_alike():
from
.pplx_prepare_finalize
import
PplxPrepareAndFinalize
from
.pplx_prepare_finalize
import
PplxPrepareAndFinalize
else
:
else
:
fused_experts
=
None
# type: ignore
fused_experts
=
None
# type: ignore
FusedMoEPermuteExpertsUnpermute
=
None
# type: ignore
FusedMoEPrepareAndFinalize
=
None
# type: ignore
FusedMoEPrepareAndFinalize
=
None
# type: ignore
if
is_rocm_aiter_moe_enabled
():
if
is_rocm_aiter_moe_enabled
():
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
...
@@ -52,8 +51,7 @@ if is_rocm_aiter_moe_enabled():
...
@@ -52,8 +51,7 @@ if is_rocm_aiter_moe_enabled():
else
:
else
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
# the iterative moe implementation is used until the moe_pallas is fixed
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
from
.moe_torch_iterative
import
fused_moe
as
fused_moe_pallas
else
:
else
:
fused_moe_pallas
=
None
# type: ignore
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -76,7 +74,8 @@ class FusedMoEParallelConfig:
...
@@ -76,7 +74,8 @@ class FusedMoEParallelConfig:
@
property
@
property
def
use_pplx_kernels
(
self
):
def
use_pplx_kernels
(
self
):
return
self
.
dp_size
>
1
and
self
.
use_ep
and
has_pplx
return
self
.
dp_size
>
1
and
self
.
use_ep
and
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"pplx"
@
staticmethod
@
staticmethod
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
...
@@ -199,6 +198,8 @@ class MoEConfig:
...
@@ -199,6 +198,8 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc.
# TODO: add more quantization params, blocked, per-token, etc.
block_size
:
int
=
128
block_size
:
int
=
128
max_num_tokens
:
int
=
MOE_DP_CHUNK_SIZE
@
property
@
property
def
tp_size
(
self
):
def
tp_size
(
self
):
return
self
.
moe_parallel_config
.
tp_size
return
self
.
moe_parallel_config
.
tp_size
...
@@ -247,13 +248,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -247,13 +248,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
raise
NotImplementedError
raise
NotImplementedError
def
set_prepare_finalize
(
def
init_prepare_finalize
(
self
,
moe
:
MoEConfig
,
self
,
quant_config
:
Optional
[
QuantizationConfig
]):
dp_size
:
int
,
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
world_size
:
int
,
assert
all2all_manager
is
not
None
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
)
->
bool
:
prepare_finalize
=
None
return
False
if
moe
.
use_pplx_kernels
:
all_to_all_args
=
dict
(
max_num_tokens
=
moe
.
max_num_tokens
,
num_experts
=
moe
.
num_experts
,
experts_per_token
=
moe
.
experts_per_token
,
# topk
rank
=
all2all_manager
.
rank
,
world_size
=
all2all_manager
.
world_size
,
# dp_size actually means tp_size, bug in pplx kernels
dp_size
=
all2all_manager
.
tp_group
.
world_size
,
hidden_dim
=
moe
.
hidden_dim
,
hidden_dim_bytes
=
moe
.
hidden_dim
*
moe
.
in_dtype
.
itemsize
,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes
=
(
0
if
moe
.
in_dtype
.
itemsize
!=
1
else
(
(
moe
.
hidden_dim
+
moe
.
block_size
-
1
)
//
moe
.
block_size
*
torch
.
float32
.
itemsize
)),
group_name
=
all2all_manager
.
cpu_group
.
group_name
,
)
handle
=
all2all_manager
.
get_handle
(
all_to_all_args
)
prepare_finalize
=
PplxPrepareAndFinalize
(
handle
,
max_num_tokens
=
moe
.
max_num_tokens
,
world_size
=
all2all_manager
.
world_size
,
rank
=
all2all_manager
.
rank
,
# dp_size actually means tp_size, bug in pplx kernels
dp_size
=
all2all_manager
.
tp_group
.
world_size
,
quant_dtype
=
moe
.
in_dtype
,
)
if
prepare_finalize
is
not
None
:
experts
=
self
.
select_gemm_impl
(
prepare_finalize
)
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
Optional
[
FusedMoEPrepareAndFinalize
]
)
->
FusedMoEPermuteExpertsUnpermute
:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise
NotImplementedError
(
"Subclass must select appropriate gemm implementation"
" based on the prepare_finalize"
)
@
abstractmethod
@
abstractmethod
def
apply
(
def
apply
(
...
@@ -277,53 +324,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -277,53 +324,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
raise
NotImplementedError
class
AllToAllCache
:
def
__init__
(
self
):
self
.
_cache
:
WeakValueDictionary
=
WeakValueDictionary
()
self
.
_lock
=
threading
.
RLock
()
# Reentrant lock for thread safety
def
destroy
(
self
):
with
self
.
_lock
:
# TODO: can we do del self._cache?
for
_
,
a2a
in
self
.
_cache
.
items
():
a2a
.
destroy
()
def
get_or_create
(
self
,
**
kwargs
):
assert
has_pplx
import
pplx_kernels
as
pplx
# Create a hashable key from the kwargs
key
=
tuple
(
sorted
((
k
,
v
)
for
k
,
v
in
kwargs
.
items
()))
with
self
.
_lock
:
instance
=
self
.
_cache
.
get
(
key
)
if
instance
is
None
:
# TODO (varun): Add support to switch to intranode
# when all communications are within the same
# node.
logger
.
debug
(
"Create AllToAll %s"
,
kwargs
)
instance
=
pplx
.
AllToAll
.
internode
(
**
kwargs
)
self
.
_cache
[
key
]
=
instance
return
instance
# Global singleton
_all_to_all_cache
=
AllToAllCache
()
# Factory function as a cleaner interface
def
get_all_to_all
(
**
kwargs
):
return
_all_to_all_cache
.
get_or_create
(
**
kwargs
)
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
"""MoE method without quantization."""
def
__init__
(
self
,
moe
:
MoEConfig
):
def
__init__
(
self
,
moe
:
MoEConfig
):
super
().
__init__
()
super
().
__init__
()
self
.
fused_experts
=
fused_experts
self
.
fused_experts
=
fused_experts
# type: ignore
self
.
moe
=
moe
self
.
moe
=
moe
self
.
rocm_aiter_moe_enabled
=
is_rocm_aiter_moe_enabled
()
self
.
rocm_aiter_moe_enabled
=
is_rocm_aiter_moe_enabled
()
...
@@ -333,6 +340,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -333,6 +340,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else
:
else
:
self
.
rocm_aiter_fused_experts
=
None
# type: ignore
self
.
rocm_aiter_fused_experts
=
None
# type: ignore
def
select_gemm_impl
(
self
,
prepare_finalize
:
Optional
[
FusedMoEPrepareAndFinalize
]):
assert
self
.
fused_experts
==
fused_experts
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
experts
:
Optional
[
FusedMoEPermuteExpertsUnpermute
]
=
None
if
isinstance
(
prepare_finalize
,
(
BatchedPrepareAndFinalize
,
PplxPrepareAndFinalize
)):
logger
.
debug
(
"BatchedTritonExperts %s"
,
self
.
moe
)
experts
=
BatchedTritonExperts
(
max_num_tokens
=
MOE_DP_CHUNK_SIZE
,
world_size
=
all2all_manager
.
world_size
,
# dp_size actually means tp_size, bug in pplx kernels
dp_size
=
all2all_manager
.
tp_group
.
world_size
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
block_shape
=
None
,
)
else
:
logger
.
debug
(
"TritonExperts %s"
,
self
.
moe
)
experts
=
TritonExperts
(
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
block_shape
=
None
,
per_channel_quant
=
False
,
)
return
experts
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
use_nn_moe
:
bool
,
params_dtype
:
torch
.
dtype
,
use_nn_moe
:
bool
,
...
@@ -392,10 +435,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -392,10 +435,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffle_weights
)
shuffle_weights
)
if
self
.
rocm_aiter_moe_enabled
:
if
self
.
rocm_aiter_moe_enabled
:
# use 2stage ck moe layout
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w2_weight
.
data
,
layout
=
(
32
,
32
))
layer
.
w13_weight
.
data
=
shuffled_w13
layer
.
w13_weight
.
data
=
shuffled_w13
layer
.
w2_weight
.
data
=
shuffled_w2
layer
.
w2_weight
.
data
=
shuffled_w2
...
@@ -448,47 +489,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -448,47 +489,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
def
set_prepare_finalize
(
self
,
dp_size
:
int
,
world_size
:
int
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
)
->
bool
:
assert
self
.
fused_experts
==
fused_experts
experts
:
Optional
[
FusedMoEPermuteExpertsUnpermute
]
=
None
if
isinstance
(
prepare_finalize
,
(
BatchedPrepareAndFinalize
,
PplxPrepareAndFinalize
)):
logger
.
debug
(
"BatchedTritonExperts %s"
,
self
.
moe
)
experts
=
BatchedTritonExperts
(
max_num_tokens
=
MOE_DP_CHUNK_SIZE
,
world_size
=
world_size
,
dp_size
=
dp_size
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
block_shape
=
None
,
)
else
:
logger
.
debug
(
"TritonExperts %s"
,
self
.
moe
)
experts
=
TritonExperts
(
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
block_shape
=
None
,
per_channel_quant
=
False
,
)
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
return
True
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -702,45 +702,6 @@ def determine_expert_map(
...
@@ -702,45 +702,6 @@ def determine_expert_map(
return
(
local_num_experts
,
expert_map
)
return
(
local_num_experts
,
expert_map
)
def
_construct_prepare_finalize
(
moe
:
MoEConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
)
->
Optional
[
FusedMoEPrepareAndFinalize
]:
max_num_tokens
=
MOE_DP_CHUNK_SIZE
world_size
=
moe
.
ep_size
dp_size
=
moe
.
ep_size
//
moe
.
dp_size
# dp_size actually means TP.
rank
=
moe
.
ep_rank
if
moe
.
use_pplx_kernels
:
logger
.
debug
(
"using PplxPrepareAndFinalize"
)
all_to_all
=
get_all_to_all
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
moe
.
num_experts
,
experts_per_token
=
moe
.
experts_per_token
,
# topk
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
moe
.
hidden_dim
,
hidden_dim_bytes
=
moe
.
hidden_dim
*
moe
.
in_dtype
.
itemsize
,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes
=
(
0
if
moe
.
in_dtype
.
itemsize
!=
1
else
((
moe
.
hidden_dim
+
moe
.
block_size
-
1
)
//
moe
.
block_size
*
torch
.
float32
.
itemsize
)))
return
PplxPrepareAndFinalize
(
all_to_all
,
max_num_tokens
=
max_num_tokens
,
world_size
=
world_size
,
rank
=
rank
,
dp_size
=
dp_size
,
quant_dtype
=
moe
.
in_dtype
,
)
return
None
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
...
@@ -854,7 +815,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -854,7 +815,10 @@ class FusedMoE(torch.nn.Module):
moe_parallel_config
=
self
.
moe_parallel_config
,
moe_parallel_config
=
self
.
moe_parallel_config
,
# TODO (bnell): this needs to be fixed for quantized types.
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype
=
params_dtype
,
in_dtype
=
params_dtype
,
max_num_tokens
=
MOE_DP_CHUNK_SIZE
,
)
)
self
.
moe_config
=
moe
self
.
quant_config
=
quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
# for heuristic purposes, so it must be initialized first.
...
@@ -862,25 +826,13 @@ class FusedMoE(torch.nn.Module):
...
@@ -862,25 +826,13 @@ class FusedMoE(torch.nn.Module):
if
quant_config
is
None
:
if
quant_config
is
None
:
quant_method
=
UnquantizedFusedMoEMethod
(
moe
)
quant_method
=
UnquantizedFusedMoEMethod
(
moe
)
prepare_finalize
=
_construct_prepare_finalize
(
moe
,
quant_config
)
else
:
else
:
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
# No pplx for quantized types yet.
prepare_finalize
=
None
assert
quant_method
is
not
None
assert
quant_method
is
not
None
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
self
.
quant_method
=
quant_method
self
.
quant_method
=
quant_method
if
prepare_finalize
is
not
None
:
world_size
=
moe
.
ep_size
dp_size
=
int
(
moe
.
ep_size
//
moe
.
dp_size
)
success
=
self
.
quant_method
.
set_prepare_finalize
(
dp_size
,
world_size
,
prepare_finalize
)
if
not
success
:
logger
.
warning
(
"DP+EP not supported for %s."
,
type
(
self
.
quant_method
))
if
quant_config
is
None
:
if
quant_config
is
None
:
# Not considering quant for now, temporarily
# Not considering quant for now, temporarily
# self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
# self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
...
...
vllm/model_executor/layers/fused_moe/moe_pallas.py
View file @
4eabe123
...
@@ -2,7 +2,23 @@
...
@@ -2,7 +2,23 @@
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch_xla.experimental.custom_kernel
import
_histogram
def
_histogram
(
input
:
torch
.
Tensor
,
min
:
int
,
max
:
int
)
->
torch
.
Tensor
:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert
input
.
dtype
==
torch
.
int32
,
"input must be of torch.int32 dtype."
assert
min
<=
max
,
"min must be less than or equal to max."
def
searchsorted
(
sorted_sequence
:
torch
.
Tensor
,
values_to_search
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
(
sorted_sequence
.
unsqueeze
(
1
)
==
values_to_search
).
sum
(
dim
=
1
)
bin_edges
=
torch
.
linspace
(
min
,
max
,
max
-
min
+
1
,
dtype
=
input
.
dtype
).
to
(
input
.
device
)
return
searchsorted
(
bin_edges
,
input
).
to
(
torch
.
int32
)
def
fused_moe
(
def
fused_moe
(
...
@@ -61,7 +77,7 @@ def fused_moe(
...
@@ -61,7 +77,7 @@ def fused_moe(
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
*
topk_weights
.
unsqueeze
_
(
dim
=-
1
)
x
=
x
*
topk_weights
.
unsqueeze
(
dim
=-
1
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
reshape
(
orig_shape
)
x
=
x
.
reshape
(
orig_shape
)
return
x
return
x
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
4eabe123
...
@@ -182,3 +182,7 @@ def moe_unpermute(
...
@@ -182,3 +182,7 @@ def moe_unpermute(
expert_first_token_offset
,
n_expert
,
expert_first_token_offset
,
n_expert
,
n_local_expert
,
topk
,
hidden_states
)
n_local_expert
,
topk
,
hidden_states
)
return
hidden_states
return
hidden_states
def moe_permute_unpermute_supported():
    """Return the result of the `_moe_C` custom op of the same name.

    NOTE(review): presumably reports whether the compiled extension was built
    with the permute/unpermute kernels -- confirm against the op definition.
    """
    moe_ops = torch.ops._moe_C
    return moe_ops.moe_permute_unpermute_supported()
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
4eabe123
...
@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
...
@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input
)
moe_kernel_quantize_input
)
# Note use: layer.get_all_to_all() to get an AllToAll instance
# The max_num_tokens, world_size and dp_size must be the same
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
# as the ones used to create the AllToAll.
class
PplxPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
class
PplxPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
enum
import
IntEnum
from
functools
import
cache
from
functools
import
cache
from
typing
import
Optional
from
typing
import
Optional
...
@@ -9,6 +10,28 @@ from vllm.platforms import current_platform
...
@@ -9,6 +10,28 @@ from vllm.platforms import current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
class QuantMethod(IntEnum):
    """Integer stand-in for the AITER QuantType enum.

    Lets this module interface with AITER's QuantType without importing it
    from AITER globally. These quantization methods are supported by the
    AITER package; not all of them are used in this module.
    """

    NO = 0  # a16w16
    PER_TENSOR = 1  # w8a8 (per_Tensor)
    PER_TOKEN = 2  # w8a8/w8a4 (per_Token)
    BLOCK_1X128 = 3  # block quantized w8a8 (per_1x128)
    BLOCK_128x128 = 4  # block quantized w8a8 (per_128x128)
class ActivationMethod(IntEnum):
    """Integer stand-in for the AITER ActivationType enum.

    Lets this module interface with AITER's ActivationType without importing
    it from AITER globally.
    """

    SILU = 0
    GELU = 1
@
cache
@
cache
def
is_rocm_aiter_moe_enabled
()
->
bool
:
def
is_rocm_aiter_moe_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
return
current_platform
.
is_rocm
()
\
...
@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl(
...
@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl(
a16
:
bool
=
False
,
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_
str
:
str
=
"si
lu
"
)
->
torch
.
Tensor
:
activation_
method
:
int
=
ActivationMethod
.
SILU
.
va
lu
e
)
->
torch
.
Tensor
:
from
aiter
import
ActivationType
from
aiter
import
ActivationType
from
aiter.fused_moe_bf16_asm
import
asm_moe_tkw1
from
aiter.fused_moe_bf16_asm
import
asm_moe_tkw1
activation
=
\
activation
=
ActivationType
(
activation_method
)
ActivationType
.
Gelu
if
activation_str
==
"gelu"
else
ActivationType
.
Silu
return
asm_moe_tkw1
(
hidden_states
,
return
asm_moe_tkw1
(
hidden_states
,
w1
,
w1
,
...
@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
...
@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
a16
:
bool
=
False
,
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_str
:
str
=
"silu"
)
->
torch
.
Tensor
:
activation_method
:
int
=
ActivationMethod
.
SILU
.
value
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        hidden_states_dtype: torch.dtype,
        expert_mask: torch.Tensor,
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        block_shape: list[int],
        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run AITER's `moe_sorting_ck` and then `fmoe_fp8_blockscale_g1u1`.

    The expert count handed to the sorting step is `w1.shape[0]`, or
    `expert_mask.numel()` when a mask is given. The weight scales are always
    viewed with `w1.shape[0]` rows. Returns the `out_asm` buffer produced by
    the sorting step (presumably filled in place by the kernel -- confirm
    against AITER).
    """
    # AITER is imported lazily so the module can be loaded without it.
    from aiter import fmoe_fp8_blockscale_g1u1
    from aiter.fused_moe_bf16_asm import moe_sorting_ck

    num_local_experts = w1.shape[0]
    num_experts = num_local_experts
    if expert_mask is not None:
        num_experts = expert_mask.numel()
    hidden_dim = w1.shape[-1]
    experts_per_token = topk_ids.shape[1]

    sorting_result = moe_sorting_ck(topk_ids,
                                    topk_weights,
                                    num_experts,
                                    hidden_dim,
                                    hidden_states_dtype,
                                    expert_mask=expert_mask)
    (sorted_token_ids, sorted_weight_buf, sorted_expert_ids, num_valid_ids,
     out_asm) = sorting_result

    a1_scale_t = a1_scale.t().contiguous()
    w1_scale_2d = w1_scale.view(num_local_experts, -1)
    w2_scale_2d = w2_scale.view(num_local_experts, -1)
    fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                             sorted_weight_buf, sorted_expert_ids,
                             num_valid_ids, experts_per_token, a1_scale_t,
                             w1_scale_2d, w2_scale_2d, *block_shape,
                             smooth_scale)
    return out_asm
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        hidden_states_dtype: torch.dtype,
        expert_mask: torch.Tensor,
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        block_shape: list[int],
        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Fake counterpart of the block-scale op: no computation is performed.

    Returns an uninitialized tensor laid out like `a1` but with dtype
    `hidden_states_dtype`. All other arguments are ignored.
    """
    placeholder = torch.empty_like(a1, dtype=hidden_states_dtype)
    return placeholder
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
                            w1: torch.Tensor,
                            w2: torch.Tensor,
                            topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor,
                            fc1_scale: Optional[torch.Tensor] = None,
                            fc2_scale: Optional[torch.Tensor] = None,
                            fc1_smooth_scale: Optional[torch.Tensor] = None,
                            fc2_smooth_scale: Optional[torch.Tensor] = None,
                            a16: bool = False,
                            activation: str = "silu") -> torch.Tensor:
    """Forward the arguments to AITER's `asm_moe` kernel entry point.

    `activation` must be "silu" or "gelu"; it is translated to the matching
    `aiter.ActivationType` member before the call. Every other argument is
    passed through by keyword (`topk_weights` under the name `topk_weight`).
    """
    # AITER is imported lazily so the module can be loaded without it.
    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
    from aiter import ActivationType

    assert activation in ["silu", "gelu"], ("The given activation:"
                                            f" {activation}"
                                            " is not supported in"
                                            " AITER.")
    aiter_activation = (ActivationType.Silu
                        if activation == "silu" else ActivationType.Gelu)

    return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                       w1=w1,
                                       w2=w2,
                                       topk_weight=topk_weights,
                                       topk_ids=topk_ids,
                                       fc1_scale=fc1_scale,
                                       fc2_scale=fc2_scale,
                                       fc1_smooth_scale=fc1_smooth_scale,
                                       fc2_smooth_scale=fc2_smooth_scale,
                                       a16=a16,
                                       activation=aiter_activation)
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
                            w1: torch.Tensor,
                            w2: torch.Tensor,
                            topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor,
                            fc1_scale: Optional[torch.Tensor] = None,
                            fc2_scale: Optional[torch.Tensor] = None,
                            fc1_smooth_scale: Optional[torch.Tensor] = None,
                            fc2_smooth_scale: Optional[torch.Tensor] = None,
                            a16: bool = False,
                            activation: str = "silu") -> torch.Tensor:
    """Fake counterpart of the asm MoE op: no computation is performed.

    Returns an uninitialized tensor with the shape and dtype of
    `hidden_states`. All other arguments are ignored.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
def rocm_aiter_ck_moe_2stages_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[list[int]] = None,
    expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward the arguments to AITER's `ck_moe_2stages`.

    Every argument is passed through by keyword. Two are renamed on the way:
    `hidden_states` becomes `a1` and `topk_weights` becomes `topk_weight`.
    """
    # AITER is imported lazily so the module can be loaded without it.
    from aiter.fused_moe_bf16_asm import ck_moe_2stages

    kernel_kwargs = {
        "a1": hidden_states,
        "w1": w1,
        "w2": w2,
        "topk_weight": topk_weights,
        "topk_ids": topk_ids,
        "fc1_scale": fc1_scale,
        "fc2_scale": fc2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "block_size": block_size,
        "expert_mask": expert_mask,
    }
    return ck_moe_2stages(**kernel_kwargs)
def rocm_aiter_ck_moe_2stages_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[list[int]] = None,
    expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake counterpart of the CK 2-stage MoE op: no computation is performed.

    Returns an uninitialized tensor with the shape and dtype of
    `hidden_states`. All other arguments are ignored.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
return
torch
.
empty_like
(
hidden_states
)
...
@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake(
...
@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake(
pass
pass
def rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
    quant_method: int = QuantMethod.NO.value,
    doweight_stage1: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call AITER's `fused_moe` with integer-coded enum arguments.

    `activation_method` and `quant_method` arrive as plain ints and are
    converted here to `aiter.ActivationType` and `aiter.QuantType`. All
    arguments are then passed positionally, in the order of this signature.
    """
    # AITER is imported lazily so the module can be loaded without it.
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    aiter_activation = ActivationType(activation_method)
    aiter_quant_type = QuantType(quant_method)

    positional_args = (hidden_states, w1, w2, topk_weight, topk_ids,
                       expert_mask, aiter_activation, aiter_quant_type,
                       doweight_stage1, w1_scale, w2_scale, a1_scale,
                       a2_scale)
    return fused_moe(*positional_args)
def rocm_aiter_fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
    quant_method: int = QuantMethod.NO.value,
    doweight_stage1: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake counterpart of the fused MoE op: no computation is performed.

    Returns an uninitialized tensor with the shape and dtype of
    `hidden_states`. All other arguments are ignored.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
direct_register_custom_op
(
direct_register_custom_op
(
...
@@ -285,26 +195,10 @@ if current_platform.is_rocm():
...
@@ -285,26 +195,10 @@ if current_platform.is_rocm():
)
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_f
moe_fp8_blockscale_g1u1
"
,
op_name
=
"rocm_aiter_f
used_moe
"
,
op_func
=
rocm_aiter_f
moe_fp8_blockscale_g1u1
_impl
,
op_func
=
rocm_aiter_f
used_moe
_impl
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
rocm_aiter_fmoe_fp8_blockscale_g1u1_fake
,
fake_impl
=
rocm_aiter_fused_moe_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_asm_moe"
,
op_func
=
rocm_aiter_asm_moe_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_asm_moe_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_ck_moe_2stages"
,
op_func
=
rocm_aiter_ck_moe_2stages_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_ck_moe_2stages_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
...
@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts(
...
@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
activation_method
=
(
ActivationMethod
.
SILU
per_token_group_quant_fp8
)
if
activation
==
"silu"
else
ActivationMethod
.
GELU
)
# All AITER Fused MoE kernels are expecting the following datatypes
# All AITER Fused MoE kernels are expecting the following datatypes
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
# w8a8 block-scaled
if
block_shape
is
not
None
and
use_fp8_w8a8
:
assert
not
apply_router_weight_on_input
,
(
"apply_router_weight_on_input is not supported for block scaled moe"
)
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# The default block sizes are 128 in AITER.
block_shape
=
[
128
,
128
]
if
block_shape
is
None
else
block_shape
a1
,
a1_scale
=
per_token_group_quant_fp8
(
hidden_states
,
block_shape
[
1
])
return
torch
.
ops
.
vllm
.
rocm_aiter_fmoe_fp8_blockscale_g1u1
(
topk_ids
,
topk_weights
,
hidden_states
.
dtype
,
None
,
a1
,
w1
,
w2
,
w1_scale
,
w2_scale
,
a1_scale
,
block_shape
,
None
)
# w8a8 per-channel quantization
# w8a8 per-channel quantization
el
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
# rather than the second FC.
...
@@ -421,42 +297,23 @@ def rocm_aiter_fused_experts(
...
@@ -421,42 +297,23 @@ def rocm_aiter_fused_experts(
a16
=
False
,
a16
=
False
,
per_tensor_quant_scale
=
None
,
per_tensor_quant_scale
=
None
,
expert_mask
=
None
,
expert_mask
=
None
,
activation_
str
=
activation
)
activation_
method
=
activation
_method
)
# w8a8 per-tensor activation per-tensor weight
else
:
elif
use_fp8_w8a8
:
quant_method
=
QuantMethod
.
NO
.
value
# w8a8 block-scaled
if
block_shape
is
not
None
and
use_fp8_w8a8
:
assert
not
apply_router_weight_on_input
,
(
assert
not
apply_router_weight_on_input
,
(
"apply_router_weight_on_input is not supported for fp8_w8a8"
)
"apply_router_weight_on_input is
\
not supported for block scaled moe"
)
# - faster static per-tensor-activation static per-tensor-weight
assert
w1_scale
is
not
None
# fp8 quantization w8a8
assert
w2_scale
is
not
None
if
a1_scale
is
not
None
and
a2_scale
is
not
None
:
quant_method
=
QuantMethod
.
BLOCK_128x128
.
value
return
torch
.
ops
.
vllm
.
rocm_aiter_ck_moe_2stages
(
elif
use_fp8_w8a8
:
hidden_states
=
hidden_states
,
# Currently only per tensor quantization method is enabled.
w1
=
w1
,
quant_method
=
QuantMethod
.
PER_TENSOR
.
value
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
w1_scale
,
fc2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
# - fallback static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
# - dynamic per-tensor activation static per-tensor-weight
# fp8 quantization w8a8
return
torch
.
ops
.
vllm
.
rocm_aiter_asm_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
w1_scale
,
fc2_scale
=
w2_scale
,
fc1_smooth_scale
=
None
,
fc2_smooth_scale
=
None
,
a16
=
False
,
activation
=
activation
)
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
assert
(
topk_weights
.
dim
()
==
2
assert
(
topk_weights
.
dim
()
==
2
),
"`topk_weights` should be in shape (num_tokens, topk)"
),
"`topk_weights` should be in shape (num_tokens, topk)"
...
@@ -465,16 +322,19 @@ def rocm_aiter_fused_experts(
...
@@ -465,16 +322,19 @@ def rocm_aiter_fused_experts(
topk
==
1
topk
==
1
),
"Only support topk=1 when `apply_router_weight_on_input` is True"
),
"Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states
=
hidden_states
*
topk_weights
.
to
(
hidden_states
.
dtype
)
return
torch
.
ops
.
vllm
.
rocm_aiter_fused_moe
(
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
hidden_states
,
topk_weights
=
torch
.
ones_like
(
topk_weights
,
dtype
=
torch
.
float32
)
w1
,
w2
,
return
torch
.
ops
.
vllm
.
rocm_aiter_ck_moe_2stages
(
topk_weights
,
hidden_states
=
hidden_states
,
topk_ids
,
w1
=
w1
,
quant_method
=
quant_method
,
w2
=
w2
,
activation_method
=
activation_method
,
topk_weights
=
topk_weights
,
w1_scale
=
w1_scale
,
topk_ids
=
topk_ids
)
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
doweight_stage1
=
apply_router_weight_on_input
)
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
...
@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
...
@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return
topk_weights
,
topk_indices
return
topk_weights
,
topk_indices
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
,
def
shuffle_weights
(
layout
:
tuple
[
int
,
int
])
->
tuple
[
torch
.
Tensor
,
...]:
*
tensors
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
]
=
(
16
,
16
)
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
"""
Applies shuffle_weight function from AITER to each
Applies shuffle_weight function from AITER to each
input tensor and returns them.
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
Args:
*tensors: Variable number of torch.Tensor objects.
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
Returns:
Returns:
A Tuple of shuffled tensors.
A Tuple of shuffled tensors.
...
@@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor,
...
@@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor,
from
aiter.ops.shuffle
import
shuffle_weight
from
aiter.ops.shuffle
import
shuffle_weight
return
tuple
(
shuffle_weight
(
tensor
,
layout
=
layout
)
for
tensor
in
tensors
)
return
tuple
(
shuffle_weight
(
tensor
,
layout
=
layout
)
for
tensor
in
tensors
)
def expand_weights(*tensors: torch.Tensor,
                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
    """
    Expands the dimensions of input tensors.

    Each tensor gets two trailing singleton dimensions and is then expanded
    to size `dim` along the first of them. The three-size `expand` call means
    each input must be 1-D: shape `(n,)` becomes `(n, dim, 1)`.

    Args:
        *tensors: A variable number of torch.Tensor objects.
        expansion_dims: A list of expansion dimensions
            corresponding to each tensor.

    Returns:
        A Tuple of tensors with expanded dimensions.
    """
    assert len(tensors) == len(expansion_dims), \
        "Number of tensors must match the number of expansion dimensions."

    expanded = []
    for weight, target_dim in zip(tensors, expansion_dims):
        with_trailing_dims = weight.unsqueeze(-1).unsqueeze(-1)
        expanded.append(with_trailing_dims.expand((-1, target_dim, -1)))
    return tuple(expanded)
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
4eabe123
...
@@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -585,8 +587,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -585,8 +587,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
2
:
self
.
qweight
=
param
.
materialize_nested
()
return
return
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -805,6 +805,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -805,6 +805,7 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -979,8 +980,6 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -979,8 +980,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
3
:
self
.
qweight
=
param
.
materialize_nested
()
return
return
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -1155,7 +1154,13 @@ class RowParallelLinear(LinearBase):
...
@@ -1155,7 +1154,13 @@ class RowParallelLinear(LinearBase):
bias can be fused with other element-wise operations.
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
reduce_results: If true, call all-reduce on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y = X_iA_i
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/layers/mamba/mamba2_metadata.py
View file @
4eabe123
...
@@ -5,10 +5,9 @@ from dataclasses import dataclass
...
@@ -5,10 +5,9 @@ from dataclasses import dataclass
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.placeholder_attn
import
(
from
vllm.attention.backends.placeholder_attn
import
(
PlaceholderAttentionMetadata
)
PlaceholderAttentionMetadata
)
from
vllm.at
tention.backends.x
form
er
s
import
XFormersMetadata
from
vllm.
pl
atforms
import
current_platform
@
dataclass
@
dataclass
...
@@ -23,6 +22,21 @@ class Mamba2Metadata:
...
@@ -23,6 +22,21 @@ class Mamba2Metadata:
chunk_offsets
:
torch
.
Tensor
chunk_offsets
:
torch
.
Tensor
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
    """Returns the appropriate metadata classes for the current platform.

    Backend modules are imported inside the matching branch, so only the
    modules for the detected platform are loaded.

    Raises:
        ValueError: If the current platform is neither ROCm nor CUDA.
    """
    if current_platform.is_rocm():
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata)
        rocm_classes = (ROCmFlashAttentionMetadata,
                        PlaceholderAttentionMetadata)
        return rocm_classes

    if current_platform.is_cuda():
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
        from vllm.attention.backends.xformers import XFormersMetadata
        cuda_classes = (FlashAttentionMetadata, XFormersMetadata,
                        PlaceholderAttentionMetadata)
        return cuda_classes

    raise ValueError(
        f"Unsupported platform for Mamba2: {current_platform.device_type}")
def
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc
:
torch
.
Tensor
,
def
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
total_seqlens
:
int
):
total_seqlens
:
int
):
...
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
...
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if
num_prefills
>
0
:
if
num_prefills
>
0
:
if
(
isinstance
(
attn_metadata
,
attn_metadata_instances
=
get_platform_metadata_classes
()
(
FlashAttentionMetadata
,
XFormersMetadata
,
if
(
isinstance
(
attn_metadata
,
attn_metadata_instances
)
PlaceholderAttentionMetadata
))
and
attn_metadata
.
context_lens_tensor
is
not
None
):
and
attn_metadata
.
context_lens_tensor
is
not
None
):
has_initial_states
=
\
has_initial_states
=
\
attn_metadata
.
context_lens_tensor
[:
num_prefills
]
>
0
#[batch,]
attn_metadata
.
context_lens_tensor
[:
num_prefills
]
>
0
#[batch,]
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
4eabe123
...
@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
@
CustomOp
.
register
(
"mixer2_gated_rms_norm"
)
@
CustomOp
.
register
(
"mixer2_gated_rms_norm"
)
class
Mixer2RMSNormGated
(
CustomOp
):
class
Mixer2RMSNormGated
(
CustomOp
):
def
__init__
(
self
,
full_hidden_size
,
full_n_groups
,
eps
=
1e-6
):
def
__init__
(
self
,
full_hidden_size
:
int
,
full_n_groups
:
int
,
use_rms_norm
:
bool
=
True
,
eps
:
float
=
1e-6
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
...
@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
self
.
n_groups
=
full_hidden_size
//
self
.
group_size
self
.
n_groups
=
full_hidden_size
//
self
.
group_size
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
self
.
use_rms_norm
=
use_rms_norm
if
self
.
use_rms_norm
:
# Register norm weight only if we're actually applying RMSNorm
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
self
.
per_rank_hidden_size
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
self
.
per_rank_hidden_size
))
set_weight_attrs
(
self
.
weight
,
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
assert
self
.
full_hidden_size
%
self
.
tp_size
==
0
,
\
else
:
"Tensor parallel world size must divide hidden size."
# Avoid checkpoint mismatch by skipping unused parameter
self
.
register_parameter
(
"weight"
,
None
)
assert
(
self
.
full_hidden_size
%
self
.
tp_size
==
0
),
"Tensor parallel world size must divide hidden size."
def
forward_native
(
def
forward_native
(
self
,
self
,
...
@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
...
@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
# the input and then redundantly compute the RMSNorm.
# the input and then redundantly compute the RMSNorm.
input_dtype
=
x
.
dtype
input_dtype
=
x
.
dtype
x
=
x
*
nn
.
functional
.
silu
(
gate
.
to
(
torch
.
float32
))
x
=
x
*
nn
.
functional
.
silu
(
gate
.
to
(
torch
.
float32
))
if
not
self
.
use_rms_norm
:
return
x
.
to
(
input_dtype
)
if
self
.
n_groups
==
1
:
if
self
.
n_groups
==
1
:
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
...
@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
...
@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
global_sums
=
tensor_model_parallel_all_reduce
(
local_sums
)
global_sums
=
tensor_model_parallel_all_reduce
(
local_sums
)
# Calculate the variance
# Calculate the variance
count
=
self
.
tp_size
*
x
.
shape
[
-
1
]
count
=
self
.
tp_size
*
x
.
shape
[
-
1
]
variance
=
(
global_sums
/
count
)
variance
=
global_sums
/
count
else
:
else
:
variance
=
x
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
variance
=
x
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
...
@@ -105,6 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
...
@@ -105,6 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
gate
:
torch
.
Tensor
,
gate
:
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
input_dtype
=
x
.
dtype
if
not
self
.
use_rms_norm
:
# Keep gate in float32 for numerical stability during silu
return
x
*
nn
.
functional
.
silu
(
gate
.
to
(
torch
.
float32
)).
to
(
input_dtype
)
if
self
.
tp_size
>
1
or
self
.
n_groups
!=
1
:
if
self
.
tp_size
>
1
or
self
.
n_groups
!=
1
:
return
self
.
forward_native
(
x
,
gate
)
return
self
.
forward_native
(
x
,
gate
)
...
@@ -182,13 +199,15 @@ def mamba_v2_sharded_weight_loader(
...
@@ -182,13 +199,15 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well.
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
# https://github.com/python/mypy/issues/2410
param
.
data
[
param
.
data
[
boundary
:(
boundary
+
take
),
# type: ignore[misc]
boundary
:(
boundary
+
take
),
...]
=
loaded_weight
[
loaded_start_idx
:(
# type: ignore[misc]
...
# type: ignore[misc]
loaded_start_idx
+
take
)]
# type: ignore[misc]
]
=
loaded_weight
[
loaded_start_idx
:(
loaded_start_idx
+
take
)
# type: ignore[misc]
]
# type: ignore[misc]
# move indexing boundaries
# move indexing boundaries
boundary
+=
shard_size
boundary
+=
shard_size
loaded_boundary
+=
(
full_dim
-
extra
)
loaded_boundary
+=
full_dim
-
extra
return
loader
return
loader
...
@@ -206,7 +225,8 @@ class MambaMixer2(CustomOp):
...
@@ -206,7 +225,8 @@ class MambaMixer2(CustomOp):
**selective** state spaces)
**selective** state spaces)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
hidden_size
:
int
,
hidden_size
:
int
,
ssm_state_size
:
int
,
ssm_state_size
:
int
,
conv_kernel_size
:
int
,
conv_kernel_size
:
int
,
...
@@ -217,8 +237,10 @@ class MambaMixer2(CustomOp):
...
@@ -217,8 +237,10 @@ class MambaMixer2(CustomOp):
num_heads
:
int
=
128
,
num_heads
:
int
=
128
,
head_dim
:
int
=
64
,
head_dim
:
int
=
64
,
rms_norm_eps
:
float
=
1e-5
,
rms_norm_eps
:
float
=
1e-5
,
activation
=
"silu"
,
activation
:
str
=
"silu"
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
use_rms_norm
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
super
().
__init__
()
# For TP, the sharding plan is as follows:
# For TP, the sharding plan is as follows:
...
@@ -238,17 +260,16 @@ class MambaMixer2(CustomOp):
...
@@ -238,17 +260,16 @@ class MambaMixer2(CustomOp):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
assert
num_heads
%
self
.
tp_size
==
0
,
\
assert
(
num_heads
%
self
.
tp_size
==
0
"Tensor parallel world size must divide num heads."
),
"Tensor parallel world size must divide num heads."
assert
(
n_groups
%
self
.
tp_size
)
==
0
or
n_groups
==
1
,
\
assert
(
n_groups
%
self
.
tp_size
)
==
0
or
n_groups
==
1
,
(
(
"If tensor parallel world size does not divide num_heads, "
"If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1."
"then num_groups must equal 1."
)
)
assert
self
.
tp_size
==
1
or
quant_config
is
None
,
\
assert
(
"Tensor parallel currently not supported for quantized models."
self
.
tp_size
==
1
or
quant_config
is
None
),
"Tensor parallel currently not supported for quantized models."
self
.
ssm_state_size
=
ssm_state_size
self
.
ssm_state_size
=
ssm_state_size
self
.
activation
=
activation
self
.
activation
=
activation
...
@@ -265,8 +286,7 @@ class MambaMixer2(CustomOp):
...
@@ -265,8 +286,7 @@ class MambaMixer2(CustomOp):
self
.
n_groups
=
n_groups
+
extra_groups_for_head_shards
(
self
.
n_groups
=
n_groups
+
extra_groups_for_head_shards
(
n_groups
,
self
.
tp_size
)
n_groups
,
self
.
tp_size
)
self
.
conv_dim
=
(
intermediate_size
+
self
.
conv_dim
=
intermediate_size
+
2
*
self
.
n_groups
*
ssm_state_size
2
*
self
.
n_groups
*
ssm_state_size
)
self
.
conv1d
=
ColumnParallelLinear
(
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
conv_kernel_size
,
input_size
=
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
output_size
=
self
.
conv_dim
,
...
@@ -279,11 +299,12 @@ class MambaMixer2(CustomOp):
...
@@ -279,11 +299,12 @@ class MambaMixer2(CustomOp):
# doesn't allow to override it
# doesn't allow to override it
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj
=
ColumnParallelLinear
(
input_size
=
hidden_size
,
self
.
in_proj
=
ColumnParallelLinear
(
output_size
=
intermediate
_size
+
input_size
=
hidden
_size
,
self
.
conv_dim
+
self
.
num_heads
,
output_size
=
intermediate_size
+
self
.
conv_dim
+
self
.
num_heads
,
bias
=
use_bias
,
bias
=
use_bias
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
)
# - because in_proj is a concatenation of 3 weights, we
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# need to interleave them before sharding
...
@@ -305,7 +326,8 @@ class MambaMixer2(CustomOp):
...
@@ -305,7 +326,8 @@ class MambaMixer2(CustomOp):
# - ditto for the otther two weights below
# - ditto for the otther two weights below
delattr
(
self
.
conv1d
.
bias
,
"weight_loader"
)
delattr
(
self
.
conv1d
.
bias
,
"weight_loader"
)
set_weight_attrs
(
set_weight_attrs
(
self
.
conv1d
.
bias
,
{
self
.
conv1d
.
bias
,
{
"weight_loader"
:
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
mamba_v2_sharded_weight_loader
(
[
[
...
@@ -316,18 +338,25 @@ class MambaMixer2(CustomOp):
...
@@ -316,18 +338,25 @@ class MambaMixer2(CustomOp):
self
.
tp_size
,
self
.
tp_size
,
tp_rank
,
tp_rank
,
)
)
})
},
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
self
.
conv1d
.
weight
,
{
"weight_loader"
:
"weight_loader"
:
mamba_v2_sharded_weight_loader
([
mamba_v2_sharded_weight_loader
(
[
intermediate_settings
,
intermediate_settings
,
group_shard_settings
,
group_shard_settings
,
group_shard_settings
,
group_shard_settings
,
],
self
.
tp_size
,
tp_rank
)
],
})
self
.
tp_size
,
tp_rank
,
)
},
)
if
quant_config
is
None
:
if
quant_config
is
None
:
# - quant layers do not have a weight loader
# - quant layers do not have a weight loader
...
@@ -345,8 +374,10 @@ class MambaMixer2(CustomOp):
...
@@ -345,8 +374,10 @@ class MambaMixer2(CustomOp):
head_setings
,
# for dt
head_setings
,
# for dt
],
],
self
.
tp_size
,
self
.
tp_size
,
tp_rank
)
tp_rank
,
})
)
},
)
# - these are TPed by heads to reduce the size of the
# - these are TPed by heads to reduce the size of the
# temporal shape
# temporal shape
...
@@ -357,6 +388,7 @@ class MambaMixer2(CustomOp):
...
@@ -357,6 +388,7 @@ class MambaMixer2(CustomOp):
))
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
num_heads
//
self
.
tp_size
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
num_heads
//
self
.
tp_size
))
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
num_heads
//
self
.
tp_size
))
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
num_heads
//
self
.
tp_size
))
self
.
use_rms_norm
=
use_rms_norm
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
a_weight_loader
=
composed_weight_loader
(
a_weight_loader
=
composed_weight_loader
(
...
@@ -365,18 +397,25 @@ class MambaMixer2(CustomOp):
...
@@ -365,18 +397,25 @@ class MambaMixer2(CustomOp):
set_weight_attrs
(
self
.
dt_bias
,
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
out_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
out_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
use_bias
,
bias
=
use_bias
,
input_is_parallel
=
True
,
input_is_parallel
=
True
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
)
self
.
norm
=
Mixer2RMSNormGated
(
intermediate_size
,
self
.
norm
=
Mixer2RMSNormGated
(
intermediate_size
,
n_groups
,
n_groups
,
self
.
use_rms_norm
,
eps
=
rms_norm_eps
)
eps
=
rms_norm_eps
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_native
(
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
self
,
hidden_states
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
):
pass
pass
def
forward_cuda
(
def
forward_cuda
(
...
@@ -384,6 +423,7 @@ class MambaMixer2(CustomOp):
...
@@ -384,6 +423,7 @@ class MambaMixer2(CustomOp):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
mamba2_metadata
:
Mamba2Metadata
,
mamba2_metadata
:
Mamba2Metadata
,
mup_vector
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# kernels to operate in continuous batching and in chunked prefill
...
@@ -401,6 +441,10 @@ class MambaMixer2(CustomOp):
...
@@ -401,6 +441,10 @@ class MambaMixer2(CustomOp):
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
if
mup_vector
is
not
None
:
projected_states
=
projected_states
*
mup_vector
gate
,
hidden_states_B_C
,
dt
=
torch
.
split
(
gate
,
hidden_states_B_C
,
dt
=
torch
.
split
(
projected_states
,
projected_states
,
[
[
...
@@ -561,6 +605,9 @@ class MambaMixer2(CustomOp):
...
@@ -561,6 +605,9 @@ class MambaMixer2(CustomOp):
hidden_states
=
torch
.
vstack
(
ssd_output_list
)
hidden_states
=
torch
.
vstack
(
ssd_output_list
)
# 4. gated MLP
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states
=
self
.
norm
(
hidden_states
,
gate
)
hidden_states
=
self
.
norm
(
hidden_states
,
gate
)
# 5. Final linear projection
# 5. Final linear projection
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
4eabe123
...
@@ -14,7 +14,7 @@ QuantizationMethods = Literal[
...
@@ -14,7 +14,7 @@ QuantizationMethods = Literal[
"ptpc_fp8"
,
"ptpc_fp8"
,
"fbgemm_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
"modelopt"
,
"
nv
fp4"
,
"
modelopt_
fp4"
,
"marlin"
,
"marlin"
,
"bitblas"
,
"bitblas"
,
"gguf"
,
"gguf"
,
...
@@ -120,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -120,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"
nv
fp4"
:
ModelOptNvFp4Config
,
"
modelopt_
fp4"
:
ModelOptNvFp4Config
,
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
"bitblas"
:
BitBLASConfig
,
"bitblas"
:
BitBLASConfig
,
"gguf"
:
GGUFConfig
,
"gguf"
:
GGUFConfig
,
...
...
Prev
1
…
22
23
24
25
26
27
28
29
30
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment