Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07c43530
Unverified
Commit
07c43530
authored
Feb 25, 2025
by
Michael Goin
Committed by
GitHub
Feb 26, 2025
Browse files
[Model] Support Grok1 (#13795)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
34e3494e
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
634 additions
and
17 deletions
+634
-17
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+5
-0
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+31
-12
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+17
-5
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+4
-0
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+2
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+3
-0
vllm/model_executor/models/grok1.py
vllm/model_executor/models/grok1.py
+565
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/source/models/supported_models.md
View file @
07c43530
...
@@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ
*
`parasail-ai/GritLM-7B-vllm`
.
*
`parasail-ai/GritLM-7B-vllm`
.
*
✅︎
*
✅︎
*
✅︎
*
✅︎
-
*
`Grok1ModelForCausalLM`
*
Grok1
*
`hpcai-tech/grok-1`
.
*
✅︎
*
✅︎
-
*
`InternLMForCausalLM`
-
*
`InternLMForCausalLM`
*
InternLM
*
InternLM
*
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc.
*
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc.
...
...
tests/models/registry.py
View file @
07c43530
...
@@ -130,6 +130,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -130,6 +130,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-160m"
),
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-160m"
),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"Grok1ModelForCausalLM"
:
_HfExamplesInfo
(
"hpcai-tech/grok-1"
,
trust_remote_code
=
True
),
"InternLMForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm-chat-7b"
,
"InternLMForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm-chat-7b"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"InternLM2ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm2-chat-7b"
,
"InternLM2ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm2-chat-7b"
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
07c43530
...
@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
...
@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1093,6 +1096,7 @@ def outplace_fused_experts(
...
@@ -1093,6 +1096,7 @@ def outplace_fused_experts(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1106,7 +1110,7 @@ def outplace_fused_experts(
...
@@ -1106,7 +1110,7 @@ def outplace_fused_experts(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
a2_scale
,
block_shape
)
...
@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
...
@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,
if
inplace
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
return
hidden_states
return
hidden_states
else
:
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
...
@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
if
activation
==
"silu"
:
intermediate_cache1
.
view
(
-
1
,
N
))
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
w2
,
...
@@ -1339,6 +1354,7 @@ def fused_moe(
...
@@ -1339,6 +1354,7 @@ def fused_moe(
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
...
@@ -1370,6 +1386,8 @@ def fused_moe(
...
@@ -1370,6 +1386,8 @@ def fused_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
...
@@ -1420,6 +1438,7 @@ def fused_moe(
...
@@ -1420,6 +1438,7 @@ def fused_moe(
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
inplace
=
inplace
,
inplace
=
inplace
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
07c43530
...
@@ -120,7 +120,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -120,7 +120,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
layer
=
layer
,
...
@@ -134,7 +135,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -134,7 +135,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map
=
expert_map
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
...
@@ -150,7 +152,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -150,7 +152,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -170,6 +173,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -170,6 +173,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
expert_map
=
expert_map
)
...
@@ -186,9 +190,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -186,9 +190,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
activation
:
str
=
"silu"
,
**
kwargs
,
**
kwargs
,
):
):
assert
custom_routing_function
is
None
assert
custom_routing_function
is
None
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported."
return
layer
.
ipex_fusion
(
return
layer
.
ipex_fusion
(
x
,
x
,
use_grouped_topk
,
use_grouped_topk
,
...
@@ -213,7 +219,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -213,7 +219,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
num_expert_group
is
None
...
@@ -225,6 +232,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -225,6 +232,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if
e_score_correction_bias
is
not
None
:
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Expert score correction bias is not supported for TPU."
)
"Expert score correction bias is not supported for TPU."
)
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported for TPU."
return
fused_moe_pallas
(
hidden_states
=
x
,
return
fused_moe_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
...
@@ -277,6 +285,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -277,6 +285,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -305,6 +314,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -305,6 +314,7 @@ class FusedMoE(torch.nn.Module):
self
.
custom_routing_function
=
custom_routing_function
self
.
custom_routing_function
=
custom_routing_function
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
activation
=
activation
self
.
expert_map
=
None
self
.
expert_map
=
None
if
self
.
ep_size
>
1
:
if
self
.
ep_size
>
1
:
...
@@ -653,7 +663,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -653,7 +663,9 @@ class FusedMoE(torch.nn.Module):
num_expert_group
=
self
.
num_expert_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
)
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
)
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
# Default set to False. (May have to add shared expert outputs.)
# Default set to False. (May have to add shared expert outputs.)
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
07c43530
...
@@ -469,7 +469,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -469,7 +469,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"Expert Parallelism is not supported for "
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
07c43530
...
@@ -219,6 +219,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -219,6 +219,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -240,6 +241,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -240,6 +241,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
@@ -550,7 +552,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -550,7 +552,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"Expert Parallelism is not supported for "
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
07c43530
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -134,6 +135,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -134,6 +135,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
use_int8_w8a16
=
True
,
use_int8_w8a16
=
True
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
07c43530
...
@@ -675,6 +675,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -675,6 +675,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -698,6 +699,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -698,6 +699,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
07c43530
...
@@ -590,7 +590,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -590,7 +590,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
# The input must currently be float16
# The input must currently be float16
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
x
=
x
.
half
()
x
=
x
.
half
()
...
...
vllm/model_executor/models/grok1.py
0 → 100644
View file @
07c43530
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Grok1 model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
# Default Grok1-specific constants, overridden by config values if present.
# Numerically these are approximately 1/sqrt(128), 1/sqrt(3) and sqrt(6144),
# in the order they are defined below.
# NOTE(review): presumably tied to Grok1's head_dim / hidden_size — confirm
# against the model config before relying on that interpretation.
DEFAULT_ATTN_OUTPUT_MULTIPLIER
=
0.08838834764831845
DEFAULT_OUTPUT_MULTIPLIER_SCALE
=
0.5773502691896257
DEFAULT_EMBEDDING_MULTIPLIER_SCALE
=
78.38367176906169
class Grok1MoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for Grok1.

    Every expert has its weights sharded across all ranks. The forward pass
    runs through a fused MoE kernel, after which the per-rank outputs are
    reduced across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Build the router (``gate``) and the fused expert layer.

        Args:
            num_experts: Output width of the router and number of experts
                handed to ``FusedMoE``.
            top_k: Forwarded to ``FusedMoE`` as ``top_k``.
            hidden_size: Input width of the router and of the experts.
            intermediate_size: Forwarded to ``FusedMoE``.
            params_dtype: Parameter dtype for both sub-layers.
            quant_config: Quantization config for the experts only.
            tp_size: Forwarded to ``FusedMoE`` as ``tp_size``.
            prefix: Parent module prefix; ``.gate`` / ``.experts`` are
                appended for the two sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now, hence the
        # explicit quant_config=None regardless of the caller's setting.
        gate_prefix = f"{prefix}.gate"
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=gate_prefix)

        experts_prefix = f"{prefix}.experts"
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                activation="gelu",
                                prefix=experts_prefix)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape, so remember
        # the incoming shape and work on a (-1, hidden_size) view.
        input_shape = hidden_states.shape
        flat_states = hidden_states.view(-1, self.hidden_size)

        # raw_logits: (num_tokens, n_experts)
        raw_logits, _ = self.gate(flat_states)
        # Squash the router logits with a scaled tanh (scale 30.0).
        capped_logits = 30.0 * F.tanh(raw_logits / 30.0)

        expert_output = self.experts(flat_states, capped_logits)
        return expert_output.view(input_shape)
class Grok1Attention(nn.Module):
    """Attention block for Grok1.

    Combines a fused QKV projection, rotary position embedding, an
    ``Attention`` layer with logit soft-capping, and an output projection.
    The projected output is optionally scaled by the config's
    ``attn_output_multiplier``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        config=None,
    ) -> None:
        """Set up head partitioning and the attention sub-layers.

        Args:
            hidden_size: Model width; also the QKV input / o_proj output size.
            num_heads: Total number of query heads before partitioning.
            num_kv_heads: Total number of key/value heads before partitioning.
            max_position: Maximum position passed to the rotary embedding.
            rope_theta: Rotary base; truncated to ``int`` when building RoPE.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Parent module prefix for the sub-layer names.
            config: Optional model config; read for ``attn_logit_softcapping``
                here and for ``attn_output_multiplier`` in ``forward``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Kept on the module because forward() reads the multiplier from it.
        self.config = config

        world_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Number of KV heads is less than TP size, so the KV heads are
            # replicated across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is at least the TP size, so the KV heads
            # are partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        attn_width = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            attn_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        # Soft cap defaults to 30.0 when the config does not provide one and
        # is clamped so it is never negative.
        configured_cap = getattr(config, "attn_logit_softcapping", 30.0)
        attn_logits_soft_cap = max(configured_cap, 0.0)

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              logits_soft_cap=attn_logits_soft_cap,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)

        attended = self.attn(query, key, value, kv_cache, attn_metadata)
        output, _ = self.o_proj(attended)

        # Apply attention output multiplier if specified in config.
        multiplier = None
        if self.config:
            multiplier = getattr(self.config, "attn_output_multiplier", None)
        if multiplier is None:
            return output
        return output * multiplier
class Grok1DecoderLayer(nn.Module):
    """One Grok1 decoder layer: attention followed by a mixture-of-experts
    block, each wrapped in a pre- and a post- RMSNorm.

    Args:
        config: Model config. Reads ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings``,
            ``intermediate_size`` and ``rms_norm_eps``; ``rope_theta``,
            ``num_experts`` and ``num_experts_per_tok`` are optional.
        cache_config: Forwarded to the attention block.
        quant_config: Forwarded to the attention and MoE blocks.
        prefix: Parameter-name prefix of this layer.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Check for fp8 quantization. The flag is only recorded here; nothing
        # in this class reads it.
        self.use_fp8 = False
        if quant_config is not None:
            self.use_fp8 = getattr(quant_config, "is_fp8_w8a8",
                                   lambda: False)()
            if not self.use_fp8 and hasattr(quant_config, "is_fp8"):
                self.use_fp8 = quant_config.is_fp8

        # Requires transformers > 4.32.0
        # Honour rope_theta from the config when it is provided; fall back to
        # the default of 10000 when it is absent or explicitly None.
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000

        self.attn = Grok1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # The attention block reads its optional settings from config.
            config=config)

        # Grok1 uses "num_experts" in its config
        num_experts = getattr(config, "num_experts", 8)
        num_experts_per_tok = getattr(config, "num_experts_per_tok", 2)

        self.moe_block = Grok1MoE(num_experts=num_experts,
                                  top_k=num_experts_per_tok,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.moe_block")

        self.pre_attn_norm = RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.pre_moe_norm = RMSNorm(config.hidden_size,
                                    eps=config.rms_norm_eps)
        self.post_moe_norm = RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is None the incoming ``hidden_states`` become the
        residual; otherwise ``pre_attn_norm`` is called with both tensors and
        returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.pre_attn_norm(hidden_states)
        else:
            hidden_states, residual = self.pre_attn_norm(
                hidden_states, residual)

        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Post attention normalization
        hidden_states = self.post_attn_norm(hidden_states)

        # MoE block with normalization
        hidden_states, residual = self.pre_moe_norm(hidden_states, residual)
        hidden_states = self.moe_block(hidden_states)
        hidden_states = self.post_moe_norm(hidden_states)

        return hidden_states, residual
@support_torch_compile
class Grok1Model(nn.Module):
    """Grok1 decoder stack.

    Embeds the input tokens (scaled by ``embedding_multiplier_scale``), runs
    the decoder layers in ``[start_layer, end_layer)`` and, on the last
    pipeline rank, applies the final RMSNorm. Other ranks hand their
    hidden states and residual on as ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.padding_idx = hf_config.pad_token_id

        # Extra vocabulary rows reserved for LoRA adapters, if any.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.embedding_multiplier_scale = getattr(
            hf_config, "embedding_multiplier_scale",
            DEFAULT_EMBEDDING_MULTIPLIER_SCALE)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            quant_config=quant_config,
        )

        def build_layer(prefix):
            # The parameter keeps the name `prefix`, matching the lambda this
            # replaces, so the factory can still be called by keyword.
            return Grok1DecoderLayer(hf_config,
                                     cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix)

        layer_range = make_layers(hf_config.num_hidden_layers,
                                  build_layer,
                                  prefix=f"{prefix}.layers")
        self.start_layer, self.end_layer, self.layers = layer_range

        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings and apply the embedding multiplier."""
        return self.embed_tokens(input_ids) * self.embedding_multiplier_scale

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` when given
        (used as-is, without the embedding multiplier) or the scaled token
        embeddings otherwise. Later ranks resume from
        ``intermediate_tensors``.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # kv_caches is indexed relative to the first layer of this rank.
        local_layers = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layers):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    kv_caches[cache_idx],
                                                    attn_metadata, residual)

        if pp_group.is_last_rank:
            normed_states, _ = self.norm(hidden_states, residual)
            return normed_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Grok1 causal language model: the decoder stack plus LM head, logits
    processor and sampler, together with the checkpoint weight loader.
    """

    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into a single qkv_proj module.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = Grok1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        # With tied embeddings the LM head shares the embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.output_multiplier_scale = getattr(
            config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                self.output_multiplier_scale)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the inner model's embedding lookup."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever the inner model returns."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each tensor is handled by the first rule that applies: rotary
        ``inv_freq`` entries are ignored, KV-cache scales are loaded through
        the quantization config, q/k/v projections are loaded as shards of
        ``qkv_proj``, expert weights go through the fused-MoE mapping, and
        everything else is loaded directly.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Map Grok1's unique expert parameter names to standard names
        # Grok1 uses "num_experts" in its config
        num_experts = getattr(self.config, "num_experts", 8)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="linear",  # Grok1 specific
            ckpt_down_proj_name="linear_1",  # Grok1 specific
            ckpt_up_proj_name="linear_v",  # Grok1 specific
            num_experts=num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        def is_skipped(candidate: str) -> bool:
            # Extra bias tensors (e.g. from GPTQ checkpoints) that have no
            # matching parameter, and parameters owned by another pipeline
            # rank, are not loaded.
            if (candidate.endswith((".bias", "_bias"))
                    and candidate not in params_dict):
                return True
            return is_pp_missing_parameter(candidate, self)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # NOTE: `name` is never rebound inside the mapping loops. The
            # first mapping that matches decides whether the tensor is loaded
            # or skipped; rebinding `name` and continuing would let later
            # mappings see a mangled (or None) name.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                mapped_name = name.replace(weight_name, param_name)
                if is_skipped(mapped_name):
                    break
                if mapped_name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    mapped_name = maybe_remap_kv_scale_name(
                        mapped_name, params_dict)
                    if mapped_name is None:
                        break
                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(mapped_name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    if is_skipped(mapped_name):
                        break
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  mapped_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    loaded_params.add(mapped_name)
                    break
                else:
                    if is_skipped(name):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    target_name = maybe_remap_kv_scale_name(name, params_dict)
                    if target_name is None:
                        continue

                    # Handle Grok1-specific norm.scale naming
                    if "norm.scale" in target_name:
                        target_name = target_name.replace("scale", "weight")

                    # Skip lm_head when tie_word_embeddings is True
                    if ("lm_head" in target_name
                            and self.config.tie_word_embeddings):
                        continue

                    param = params_dict[target_name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(target_name)
        return loaded_params
vllm/model_executor/models/registry.py
View file @
07c43530
...
@@ -60,6 +60,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -60,6 +60,7 @@ _TEXT_GENERATION_MODELS = {
"GraniteForCausalLM"
:
(
"granite"
,
"GraniteForCausalLM"
),
"GraniteForCausalLM"
:
(
"granite"
,
"GraniteForCausalLM"
),
"GraniteMoeForCausalLM"
:
(
"granitemoe"
,
"GraniteMoeForCausalLM"
),
"GraniteMoeForCausalLM"
:
(
"granitemoe"
,
"GraniteMoeForCausalLM"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"Grok1ModelForCausalLM"
:
(
"grok1"
,
"Grok1ForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternLM2VEForCausalLM"
:
(
"internlm2_ve"
,
"InternLM2VEForCausalLM"
),
"InternLM2VEForCausalLM"
:
(
"internlm2_ve"
,
"InternLM2VEForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment