Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4322c31e
Unverified
Commit
4322c31e
authored
May 01, 2025
by
ryang
Committed by
GitHub
May 01, 2025
Browse files
Support XiaomiMiMo/MiMo model inference (#5921)
parent
9858113c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
172 additions
and
0 deletions
+172
-0
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+1
-0
python/sglang/srt/models/xiaomi_mimo.py
python/sglang/srt/models/xiaomi_mimo.py
+171
-0
No files found.
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
4322c31e
...
...
@@ -107,6 +107,7 @@ class FlashInferAttnBackend(AttentionBackend):
if
(
"Qwen2ForCausalLM"
in
model_runner
.
model_config
.
hf_config
.
architectures
or
"Qwen3ForCausalLM"
in
model_runner
.
model_config
.
hf_config
.
architectures
or
"MiMoForCausalLM"
in
model_runner
.
model_config
.
hf_config
.
architectures
):
global_config
.
flashinfer_workspace_size
=
512
*
1024
*
1024
...
...
python/sglang/srt/models/xiaomi_mimo.py
0 → 100644
View file @
4322c31e
# Adapted from qwen2.py
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
sglang.srt.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
)
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen2
import
Qwen2DecoderLayer
,
Qwen2MLP
,
Qwen2Model
from
sglang.srt.utils
import
add_prefix
MiMoConfig
=
None
class
MiMoModel
(
Qwen2Model
):
def
__init__
(
self
,
config
:
MiMoConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
decoder_layer_type
=
Qwen2DecoderLayer
,
)
class
MiMoForCausalLM
(
nn
.
Module
):
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
:
MiMoConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
MiMoModel
(
config
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"model"
,
prefix
)
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"lm_head"
,
prefix
),
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
get_embedding
:
bool
=
False
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
)
if
not
get_embedding
:
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
)
else
:
return
self
.
pooler
(
hidden_states
,
forward_batch
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
(
"rotary_emb.inv_freq"
in
name
or
"projector"
in
name
or
"mtp_layers"
in
name
):
continue
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
name
.
startswith
(
"model.vision_tower"
)
and
name
not
in
params_dict
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
get_embed_and_head
(
self
):
return
self
.
model
.
embed_tokens
.
weight
,
self
.
lm_head
.
weight
def
set_embed_and_head
(
self
,
embed
,
head
):
del
self
.
model
.
embed_tokens
.
weight
del
self
.
lm_head
.
weight
self
.
model
.
embed_tokens
.
weight
=
embed
self
.
lm_head
.
weight
=
head
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
synchronize
()
def
load_kv_cache_scales
(
self
,
quantization_param_path
:
str
)
->
None
:
self
.
model
.
load_kv_cache_scales
(
quantization_param_path
)
EntryClass
=
MiMoForCausalLM
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment