Commit a59636bb (unverified)
Update grok 1 model (#1095)
Authored Aug 14, 2024 by Lianmin Zheng; committed by GitHub on Aug 14, 2024
Parent: fe502432
Showing 11 changed files with 813 additions and 513 deletions (+813, -513)
benchmark/gsm8k/bench_sglang.py  (+3, -0)
python/sglang/bench_latency.py  (+1, -0)
python/sglang/srt/layers/activation.py  (+0, -1)
python/sglang/srt/layers/fused_moe/__init__.py  (+1, -0)
python/sglang/srt/layers/fused_moe/fused_moe.py  (+165, -108)
python/sglang/srt/layers/fused_moe/layer.py  (+587, -0)
python/sglang/srt/layers/logits_processor.py  (+4, -4)
python/sglang/srt/model_executor/model_runner.py  (+2, -2)
python/sglang/srt/models/grok.py  (+49, -395)
python/sglang/srt/models/mixtral.py  (+0, -1)
python/sglang/srt/utils.py  (+1, -2)
benchmark/gsm8k/bench_sglang.py

@@ -88,6 +88,9 @@ def main(args):
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # print(f"{preds=}")
    # print(f"{labels=}")

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
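For reference, the accuracy bookkeeping above compares the numeric answers parsed from each completion against the labels, and separately tracks the fraction of completions that could not be parsed at all. A minimal, self-contained sketch (the INVALID sentinel value below is assumed for illustration, not taken from the diff):

import numpy as np

INVALID = -9999999                    # sentinel used when no numeric answer can be parsed
preds = np.array([42, 7, INVALID, 3])
labels = np.array([42, 8, 5, 3])

acc = np.mean(preds == labels)        # 0.5  -> half of the answers match the labels
invalid = np.mean(preds == INVALID)   # 0.25 -> a quarter of the completions were unparseable
print(f"{acc=} {invalid=}")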
python/sglang/bench_latency.py

@@ -221,6 +221,7 @@ def correctness_test(
    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
    rank_print(f"{input_ids=}")

    if bench_args.cut_len > 0:
        # Prefill
python/sglang/srt/layers/activation.py

@@ -14,7 +14,6 @@ limitations under the License.
"""Fused operators for activation layers."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul
from vllm.model_executor.custom_op import CustomOp
python/sglang/srt/layers/fused_moe/__init__.py (new file, mode 0 → 100644)

from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
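With the new package layout, the layer classes can be imported from the package itself rather than from the submodule. A minimal sketch (assuming sglang and its GPU-side dependencies are installed):

# The package __init__ re-exports the classes defined in fused_moe/layer.py,
# so both import paths below resolve to the same objects.
from sglang.srt.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from sglang.srt.layers.fused_moe.layer import FusedMoE as _FusedMoE

assert FusedMoE is _FusedMoE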
python/sglang/srt/layers/fused_moe.py → python/sglang/srt/layers/fused_moe/fused_moe.py (renamed)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel."""
import functools
import json
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -373,6 +359,31 @@ def get_default_config(
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        configs = get_moe_configs(E, N, dtype)
        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
    return config


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
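The lookup above picks the tuned kernel configuration whose benchmarked batch size is closest to the current token count M. A minimal sketch of just that selection rule (the config dictionary below is made up for illustration):

# Hypothetical tuned configs, keyed by the batch size M they were benchmarked at.
configs = {
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32},
    64: {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
    1024: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
}

M = 200  # current number of tokens
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # the M=64 entry wins: |64 - 200| is the smallest distance to any key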
@@ -403,6 +414,41 @@ def fused_topk(
    return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
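grouped_topk first keeps only the topk_group strongest expert groups per token, then selects topk experts inside those groups. A minimal CPU-side usage sketch with toy tensors (import path from the renamed module above; importing it requires triton and vllm to be installed, and the 8-experts-in-4-groups sizes are made up for illustration):

import torch

from sglang.srt.layers.fused_moe.fused_moe import grouped_topk

hidden_states = torch.randn(4, 16)   # 4 tokens, hidden size 16 (arbitrary)
gating_output = torch.randn(4, 8)    # router logits over 8 experts = 4 groups x 2 experts

topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,               # pick 2 experts per token...
    renormalize=True,     # ...and renormalize their weights to sum to 1
    num_expert_group=4,
    topk_group=2,         # but only inside the 2 best-scoring groups
)
print(topk_ids.shape)                 # torch.Size([4, 2]): selected expert ids per token
print(topk_weights.sum(dim=-1))       # each row sums to 1 after renormalization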
@@ -425,24 +471,23 @@ def fused_experts(
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )

-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
+    config = get_config_func(M)

     intermediate_cache1 = torch.empty(
         (M, topk_ids.shape[1], N),
@@ -460,56 +505,85 @@ def fused_experts(
         dtype=hidden_states.dtype,
     )

-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)

+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (
+            chunk * CHUNK_SIZE,
+            min((chunk + 1) * CHUNK_SIZE, num_tokens),
+        )
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            # Adjust the intermediate cache size and config for the last
+            # chunk. Note that in most cases we only have one chunk
+            # so the cache size and config are already set correctly and
+            # do not need to be adjusted.
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            config = get_config_func(tokens_in_chunk)
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            curr_topk_ids, config["BLOCK_SIZE_M"], E
+        )
+
+        invoke_fused_moe_kernel(
+            curr_hidden_states,
+            w1,
+            intermediate_cache1,
+            a1_scale,
+            w1_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            topk_ids.shape[1],
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )

-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(
+            intermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2_scale,
+            w2_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            True,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )

-    if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+        torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+        )
+
+    return out_hidden_states


 def fused_moe(
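The rewritten fused_experts no longer launches one kernel pass over the whole batch; it walks the tokens in slices of at most VLLM_FUSED_MOE_CHUNK_SIZE so the intermediate caches stay bounded (see the vllm issue linked in the code). A minimal sketch of just the chunk-index arithmetic, as a standalone helper that is not part of the diff:

# Yields the (begin, end) token ranges that the loop above iterates over,
# skipping the empty trailing range when num_tokens is a multiple of chunk_size.
def chunk_ranges(num_tokens: int, chunk_size: int):
    for chunk in range(num_tokens // chunk_size + 1):
        begin = chunk * chunk_size
        end = min((chunk + 1) * chunk_size, num_tokens)
        if end > begin:
            yield begin, end

print(list(chunk_ranges(70_000, 32_768)))
# [(0, 32768), (32768, 65536), (65536, 70000)]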
@@ -521,6 +595,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -543,6 +620,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -556,12 +637,18 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

-    if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
         )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)
@@ -579,33 +666,3 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
     )
-
-
-def fused_topk_v0_4_3(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    import vllm._moe_C as moe_kernels
-
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-    moe_kernels.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
python/sglang/srt/layers/fused_moe/layer.py (new file, mode 0 → 100644)

This diff is collapsed (+587 lines).
python/sglang/srt/layers/logits_processor.py

@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
             last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
-            last_logits /= self.config.final_logit_softcapping
+            last_logits.div_(self.config.final_logit_softcapping)
             last_logits = torch.tanh(last_logits)
-            last_logits *= self.config.final_logit_softcapping
+            last_logits.mul_(self.config.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
             all_logits = all_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
-            all_logits /= self.config.final_logit_softcapping
+            all_logits.div_(self.config.final_logit_softcapping)
             all_logits = torch.tanh(all_logits)
-            all_logits *= self.config.final_logit_softcapping
+            all_logits.mul_(self.config.final_logit_softcapping)

         all_logprobs = all_logits
         del all_logits, hidden_states
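Both hunks replace the out-of-place /= and *= on the logits with in-place div_ and mul_, which avoids allocating extra full-vocabulary temporaries. A minimal sketch of the softcapping math itself (the cap value is made up; models that use this set final_logit_softcapping in their config):

import torch

cap = 30.0                          # hypothetical final_logit_softcapping value
logits = torch.randn(2, 32000) * 100.0

logits.div_(cap)                    # in-place: no second (2, 32000) buffer is created
logits = torch.tanh(logits)
logits.mul_(cap)                    # values are now softly bounded to (-cap, cap)

assert logits.abs().max() <= cap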
python/sglang/srt/model_executor/model_runner.py

@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8,
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -158,7 +158,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )

-        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
python/sglang/srt/models/grok.py

This diff is collapsed (+49, -395 lines).
python/sglang/srt/models/mixtral.py

@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
python/sglang/srt/utils.py

@@ -35,7 +35,6 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
    FileCacheManager,

@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
         logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")


-def is_llama3_405b_fp8(model_config):
+def is_llama3_405b_fp8_head_16(model_config):
     """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
     if (
         model_config.hf_config.architectures[0] == "LlamaForCausalLM"