change / sglang · Commits · dd5eba4c

Remove fused_moe_grok (#2223)

Unverified commit dd5eba4c, authored Nov 27, 2024 by Lianmin Zheng, committed by GitHub on Nov 27, 2024.
Parent: a4fd2f9b
Showing 7 changed files with 12 additions and 1372 deletions (+12, -1372).
3rdparty/amd/tuning/benchmark_moe_rocm.py  (+1, -1)
python/sglang/srt/layers/fused_moe_grok/__init__.py  (+0, -1)
python/sglang/srt/layers/fused_moe_grok/fused_moe.py  (+0, -692)
python/sglang/srt/layers/fused_moe_grok/layer.py  (+0, -630)
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json  (+0, -0)
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json  (+0, -0)
python/sglang/srt/models/grok.py  (+11, -48)
3rdparty/amd/tuning/benchmark_moe_rocm.py

@@ -10,7 +10,7 @@ import triton.language as tl
 from tqdm import tqdm
 from transformers import AutoConfig

-from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe, get_config_file_name
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name

 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
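The only functional change in this tuning script is the import path; the surrounding logic, such as the MOE_PADDING toggle visible in the unchanged context line, is untouched. For reference, a minimal stand-alone sketch of that toggle (the helper name moe_padding_size is ours, not the script's):

import os

def moe_padding_size() -> int:
    # Same expression as the context line: the env var defaults to "0",
    # and any non-zero integer enables a 128-element pad.
    return 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

if __name__ == "__main__":
    os.environ.pop("MOE_PADDING", None)
    assert moe_padding_size() == 0      # unset -> no padding
    os.environ["MOE_PADDING"] = "1"
    assert moe_padding_size() == 128    # enabled -> pad to 128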
python/sglang/srt/layers/fused_moe_grok/__init__.py  (deleted, 100644 → 0)

-from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase
python/sglang/srt/layers/fused_moe_grok/fused_moe.py  (deleted, 100644 → 0; diff collapsed, 692 lines removed)
python/sglang/srt/layers/fused_moe_grok/layer.py  (deleted, 100644 → 0; diff collapsed, 630 lines removed)
python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json → python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json  (file moved)
python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json → python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json  (file moved)
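The two relocated JSON files hold pre-tuned Triton MoE kernel configurations; their names encode the expert count (E), a sharded weight dimension (N), the device name, and the dtype. As a rough illustration of how a file with this naming scheme could be located and loaded (this is a sketch under that assumption, not the actual lookup code in fused_moe_triton):

import json
import os

def load_tuned_moe_config(configs_dir: str, E: int, N: int, device_name: str, dtype: str):
    """Return the tuned kernel config dict for (E, N, device, dtype), or None if absent."""
    fname = f"E={E},N={N},device_name={device_name},dtype={dtype}.json"
    path = os.path.join(configs_dir, fname)
    if not os.path.isfile(path):
        return None  # caller would fall back to a default configuration
    with open(path) as f:
        return json.load(f)

# Example: the file moved by this commit.
cfg = load_tuned_moe_config(
    "python/sglang/srt/layers/fused_moe_triton/configs",
    E=8, N=4096, device_name="AMD_Instinct_MI300X", dtype="float8",
)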
python/sglang/srt/models/grok.py

@@ -16,22 +16,17 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-import warnings
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.fused_moe_grok import FusedMoE
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
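The sglang-side import switches from the removed fused_moe_grok package to fused_moe_triton; both paths appear verbatim in the hunk above. Out-of-tree code that still imports the removed package will raise ImportError after this commit. A transitional shim along these lines is one option (our own suggestion, not part of this commit):

try:
    # Removed by this commit.
    from sglang.srt.layers.fused_moe_grok import FusedMoE
except ImportError:
    # The maintained implementation after this commit.
    from sglang.srt.layers.fused_moe_triton import FusedMoE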
@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-
-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)

     def forward(
         self,
         input_ids: torch.Tensor,
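The removed constructor lines monkey-patched vLLM's DefaultModelLoader._prepare_weights so that pre-sharded Grok-1 checkpoints could be discovered per tensor-parallel rank; the patched function itself is deleted in the last hunk below. A generic, self-contained sketch of that setattr-based patching pattern, using stand-in names (Loader, patched_prepare) rather than the vLLM classes:

class Loader:
    """Stand-in for a model loader with an overridable hook."""
    def _prepare_weights(self, path: str):
        return f"default weights from {path}"

_original_prepare = Loader._prepare_weights

def patched_prepare(self, path: str):
    # Intercept the call; fall back to the original behaviour when not needed.
    if not path.endswith("-presharded"):
        return _original_prepare(self, path)
    return f"presharded weights from {path}"

# The removed code did the equivalent of this at model construction time.
setattr(Loader, "_prepare_weights", patched_prepare)

print(Loader()._prepare_weights("ckpt"))             # default weights from ckpt
print(Loader()._prepare_weights("ckpt-presharded"))  # presharded weights from ckpt-presharded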
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
-
-                if self.use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": self.use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
                     param,
                     loaded_weight,
-                    weight_name,
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 if name is None:
                     continue
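The loop above maps a checkpoint tensor name onto the fused expert parameter via name.replace(weight_name, param_name) and then dispatches to that parameter's weight_loader with shard_id and expert_id; after this commit the extra use_presharded_weights kwarg is dropped and the mapped name is passed instead of weight_name. A toy, self-contained sketch of that control flow (the mapping tuples and the recording loader below are invented for illustration, not sglang's actual names):

import torch

# (param_name, weight_name, expert_id, shard_id): same tuple shape the loop iterates over.
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.w1.weight", 0, "w1"),
    ("experts.w2_weight", "experts.0.w2.weight", 0, "w2"),
]

calls = []

def weight_loader(param, loaded_weight, name, shard_id, expert_id):
    # Stand-in for the parameter's weight_loader: record where the shard would go.
    calls.append((name, shard_id, expert_id, tuple(loaded_weight.shape)))

params_dict = {"experts.w13_weight": torch.empty(1), "experts.w2_weight": torch.empty(1)}

name, loaded_weight = "experts.0.w1.weight", torch.zeros(4, 8)
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name not in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]
    weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
    break

print(calls)  # [('experts.w13_weight', 'w1', 0, (4, 8))]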
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-
-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))


 class Grok1ModelForCausalLM(Grok1ForCausalLM):
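The deleted _prepare_presharded_weights helper bypassed the default loader whenever the tensor-parallel world size was greater than 1 and globbed per-rank shard files named like *-000.bin, *-001.bin, and so on inside the checkpoint folder. A small stand-alone sketch of that discovery step, mirroring the removed body (here tp_rank and the folder path are supplied directly by the caller):

import glob
import os
from typing import List, Tuple

def find_presharded_files(hf_folder: str, tp_rank: int) -> Tuple[str, List[str], bool]:
    """Match '*-{rank:03d}.bin' shards for one TP rank; safetensors are never used."""
    allow_patterns = [f"*-{tp_rank:03d}.bin"]
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    use_safetensors = False
    return hf_folder, hf_weights_files, use_safetensors

# e.g. rank 2 of a pre-sharded checkpoint directory
folder, files, use_st = find_presharded_files("/models/grok-1-presharded", tp_rank=2)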