sglang · Commit 75ce37f4 (unverified)

Move sampler into CUDA graph (#1201)

Authored Aug 26, 2024 by Liangsheng Yin; committed by GitHub on Aug 26, 2024.
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: 97589a60

Changes: 28 files in total; this page shows 20 changed files with 230 additions and 87 deletions (+230, -87).
File                                                    Additions  Deletions
python/sglang/srt/layers/logits_processor.py               +4         -4
python/sglang/srt/layers/sampler.py                       +68        -15
python/sglang/srt/managers/schedule_batch.py              +20         -8
python/sglang/srt/managers/tp_worker.py                   +32        -20
python/sglang/srt/model_executor/cuda_graph_runner.py     +24         -9
python/sglang/srt/model_executor/forward_batch_info.py     +8         -1
python/sglang/srt/model_executor/model_runner.py          +11         -3
python/sglang/srt/models/chatglm.py                        +4        -12
python/sglang/srt/models/commandr.py                       +5         -1
python/sglang/srt/models/dbrx.py                           +5         -1
python/sglang/srt/models/deepseek.py                       +5         -1
python/sglang/srt/models/deepseek_v2.py                    +5         -1
python/sglang/srt/models/gemma.py                          +5         -1
python/sglang/srt/models/gemma2.py                         +5         -1
python/sglang/srt/models/gpt_bigcode.py                    +5         -1
python/sglang/srt/models/grok.py                           +5         -1
python/sglang/srt/models/internlm2.py                      +5         -1
python/sglang/srt/models/llama2.py                         +7         -3
python/sglang/srt/models/llama_classification.py           +2         -2
python/sglang/srt/models/minicpm.py                        +5         -1
python/sglang/srt/layers/logits_processor.py
Renames the LogitProcessorOutput dataclass to LogitsProcessorOutput throughout.

@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad

 @dataclasses.dataclass
-class LogitProcessorOutput:
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]

@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,

@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
         else:
             output_top_logprobs = None
-        return LogitProcessorOutput(
+        return LogitsProcessorOutput(
             next_token_logits=last_logits,
             next_token_logprobs=last_logprobs,
             normalized_prompt_logprobs=None,

@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
         # Remove the last token logprob for the prefill tokens.
         input_token_logprobs = input_token_logprobs[:-1]
-        return LogitProcessorOutput(
+        return LogitsProcessorOutput(
             next_token_logits=last_logits,
             next_token_logprobs=last_logprobs,
             normalized_prompt_logprobs=normalized_prompt_logprobs,
python/sglang/srt/layers/sampler.py
The sampler gains a SampleOutput dataclass, splits logit post-processing into _apply_penalties and _get_probs, and accepts a LogitsProcessorOutput directly.

+import dataclasses
 import logging
+from typing import Union

 import torch
 from flashinfer.sampling import (

@@ -9,6 +11,8 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)


+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()

-    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         if sampling_info.vocab_mask is not None:
             logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

-        logits = sampling_info.penalizer_orchestrator.apply(logits)
+        logits = self._apply_penalties(logits, sampling_info)

-        probs = torch.softmax(logits, dim=-1)
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)

         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.min_ps.any():
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
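The `_apply_penalties` helper folds the former penalizer-orchestrator call into the sampler itself. Its repetition branch divides positive logits and multiplies negative ones by `scaling_penalties`, so a penalized token's score always moves toward zero. A minimal standalone sketch of that rule (toy tensors; no flashinfer or sglang state required):

import torch

# Toy logits: batch of 2 sequences over a 4-token vocabulary.
logits = torch.tensor([[2.0, -1.0, 0.5, -0.5],
                       [1.0, 3.0, -2.0, 0.0]])

# Per-sequence repetition penalties (> 1 discourages repeats), broadcast over vocab.
scaling_penalties = torch.tensor([[1.2], [1.5]])

# Same rule as Sampler._apply_penalties: dividing a positive logit or
# multiplying a negative one both shrink the score toward zero.
penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,
    logits * scaling_penalties,
)
print(penalized)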
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )

-        if not torch.all(success):
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)
-
-        return batch_next_token_ids
+        return SampleOutput(success, probs, batch_next_token_ids)

-    def forward_native():
-        raise NotImplementedError("Native forward is not implemented yet.")
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)


 def top_k_top_p_min_p_sampling_from_probs_torch(
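The flashinfer path above now gates its min-p branch on `need_min_p_sampling` rather than `min_ps.any()`. Min-p filtering zeroes any probability below a per-row threshold; a pure-torch sketch of the filter, assuming the usual min-p definition of threshold = min_p × the row's maximum probability (the same quantity the `min_p_thresholds` tensor in the next hunk holds):

import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05],
                      [0.40, 0.40, 0.10, 0.10]])
min_ps = torch.tensor([0.3, 0.2])

# Assumed min-p rule: keep tokens whose probability is at least
# min_p times the row's most likely token.
min_p_thresholds = probs.max(dim=-1).values * min_ps
filtered = probs.clone()
filtered[filtered < min_p_thresholds.view(-1, 1)] = 0.0
print(filtered)  # row 0 drops the 0.05 entry; row 1 keeps everything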
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(
+            probs_sort, num_samples=2, replacement=True
+        )[:, :1]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
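The FIXME above works around a `torch.multinomial` issue with single-sample draws by drawing two samples with replacement and keeping only the first column, which preserves the `[batch, 1]` shape a `num_samples=1` call would produce. A quick standalone check of the shape behavior:

import torch

probs_sort = torch.tensor([[0.7, 0.2, 0.1, 0.0],
                           [0.1, 0.6, 0.3, 0.0]])

# Draw two samples per row with replacement, then keep only the first,
# mirroring the workaround in the hunk above.
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[:, :1]
print(sampled_index.shape)  # torch.Size([2, 1])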
python/sglang/srt/managers/schedule_batch.py
ScheduleBatch.sample is replaced by check_sample_results, which repairs failed rows after the graph has already sampled.

+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -17,7 +19,7 @@ limitations under the License.
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch

@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access

@@ -671,11 +677,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

-    def sample(self, logits: torch.Tensor):
-        from sglang.srt.layers.sampler import Sampler
-
-        sampler = Sampler()
-
-        batch_next_token_ids = sampler(logits, self.sampling_info)
-
-        return batch_next_token_ids
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
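With the sampler captured inside the CUDA graph, failure handling can no longer live in the sampler's own forward; `check_sample_results` performs the repair after replay instead. A toy run of the same fallback rule:

import torch

# Pretend row 1 of the batch failed sampling and produced NaN probs.
success = torch.tensor([True, False])
probs = torch.tensor([[0.9, 0.1], [float("nan"), 0.5]])
batch_next_token_ids = torch.tensor([0, 0])

if not torch.all(success):
    # Same repair as ScheduleBatch.check_sample_results: zero the NaNs,
    # then take greedy (top_k=1) picks for the failed rows only.
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(probs, dim=-1)
    batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

print(batch_next_token_ids)  # tensor([0, 1]): row 1 replaced by its argmax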
python/sglang/srt/managers/tp_worker.py
The worker now unpacks (sample_output, logits_output) from the model runner and routes the sampling fallback through ScheduleBatch.check_sample_results.

@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,

@@ -486,21 +486,29 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )

                 # Move logprobs to cpu
-                if output.next_token_logprobs is not None:
-                    output.next_token_logprobs = output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=next_token_ids.device),
-                        next_token_ids,
-                    ].tolist()
-                    output.input_token_logprobs = output.input_token_logprobs.tolist()
-                    output.normalized_prompt_logprobs = (
-                        output.normalized_prompt_logprobs.tolist()
-                    )
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )

                 next_token_ids = next_token_ids.tolist()

@@ -539,12 +547,14 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)
                 if req.return_logprob:
-                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = output.embeddings.tolist()
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()

         # Check finish conditions
         for i, req in enumerate(batch.reqs):

@@ -572,7 +582,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitProcessorOutput,
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

@@ -654,15 +664,17 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )

         # Move logprobs to cpu
-        if output.next_token_logprobs is not None:
-            next_token_logprobs = output.next_token_logprobs[
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()

@@ -688,7 +700,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(output.output_top_logprobs[i])
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

         self.handle_finished_requests(batch)
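Both the extend and decode paths above pull one logprob per sequence out of a `[batch, vocab]` tensor by pairing `torch.arange(batch)` row indices with the sampled token ids as column indices. The indexing pattern in isolation:

import torch

torch.manual_seed(0)
next_token_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)  # [batch, vocab]
next_token_ids = torch.tensor([4, 0, 2])

# Advanced indexing: row i is paired with column next_token_ids[i],
# yielding one scalar logprob per sequence.
picked = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(picked)  # three floats, one per sequence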
python/sglang/srt/model_executor/cuda_graph_runner.py
The graph runner now captures the sampler as well: it keeps a persistent SamplingBatchInfo buffer, copies each batch's sampling parameters into it before replay, and returns (sample_output, logits_output).

@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -143,6 +145,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]

+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
         if use_torch_compile:

@@ -234,6 +240,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,

@@ -298,27 +305,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        output = self.output_buffers[bs]
+        sample_output, logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )

         # Extract logprobs
         if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:

@@ -326,8 +341,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]

-        return output
+        return sample_output, logits_output
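A replayed CUDA graph always reads and writes the exact memory it was captured with, so per-request sampling parameters must be copied into a persistent buffer sized for `max_bs` before each replay; that is the role of `SamplingBatchInfo.dummy_one` (allocate once) and `inplace_assign` (overwrite the live prefix). A bare-tensor imitation of the pattern, runnable on CPU (the helper below is a stand-in, not an sglang API):

import torch

max_bs = 32

# Allocated once before capture: a graph would read from this exact
# storage, so it must be updated in place, never reassigned.
temperatures = torch.ones((max_bs, 1))

def inplace_assign(raw_bs: int, new_temperatures: torch.Tensor) -> None:
    # Overwrite only the prefix the current (padded) batch will use.
    temperatures[:raw_bs].copy_(new_temperatures)

inplace_assign(3, torch.tensor([[0.7], [1.0], [0.2]]))
print(temperatures[:4].squeeze(-1))  # first three rows updated, padding untouched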
python/sglang/srt/model_executor/forward_batch_info.py
InputMetadata now carries the batch's SamplingBatchInfo so the sampler can run inside the captured graph.

+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -16,7 +18,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch

@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):

@@ -42,6 +45,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor

@@ -179,6 +183,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,

@@ -189,6 +194,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)
         ret.compute_extend_infos(batch)
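Both this file and schedule_batch.py gain `from __future__ import annotations` plus a `TYPE_CHECKING`-guarded import, the standard recipe for annotating with a type whose module would otherwise create a circular import at runtime. The general shape of the idiom:

from __future__ import annotations  # all annotations become lazily evaluated strings

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no import cycle.
    from decimal import Decimal  # stand-in for SamplingBatchInfo

@dataclass
class InputMetadataSketch:  # trimmed stand-in for InputMetadata
    sampling_info: Decimal

print(InputMetadataSketch.__annotations__["sampling_info"])  # 'Decimal'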
python/sglang/srt/model_executor/model_runner.py
forward now returns a (SampleOutput, LogitsProcessorOutput) tuple, and CUDA-graph replay is skipped when the batch carries a logit bias.

@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type

 import torch
 import torch.nn as nn

@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,

@@ -514,7 +516,11 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(

@@ -563,7 +569,9 @@ class ModelRunner:
             input_metadata.image_offsets,
         )

-    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
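Every caller of `ModelRunner.forward` now unpacks two values. A self-contained sketch of the new return contract, with trimmed stand-ins for the two dataclasses (the real ones carry more fields):

from dataclasses import dataclass
from typing import Tuple

import torch

@dataclass
class SampleOutput:  # mirrors the new dataclass in sampler.py
    success: torch.Tensor
    probs: torch.Tensor
    batch_next_token_ids: torch.Tensor

@dataclass
class LogitsProcessorOutput:  # trimmed stand-in
    next_token_logits: torch.Tensor

def forward(logits: torch.Tensor) -> Tuple[SampleOutput, LogitsProcessorOutput]:
    # After this commit, forward returns (sample_output, logits_output)
    # instead of a lone logits object.
    probs = torch.softmax(logits, dim=-1)
    ids = torch.argmax(probs, dim=-1)
    success = torch.ones(logits.shape[0], dtype=torch.bool)
    return SampleOutput(success, probs, ids), LogitsProcessorOutput(logits)

sample_output, logits_output = forward(torch.randn(2, 8))
print(sample_output.batch_next_token_ids.shape)  # torch.Size([2])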
python/sglang/srt/models/chatglm.py
First of the per-model changes: swap vllm's Sampler for sglang's, delete the vllm-style sample() method, and return (sample_output, logits_output) from forward.

@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None

@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
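chatglm.py sets the template the remaining model files repeat: build sglang's Sampler in `__init__` and return `(sample_output, logits_output)` from `forward`. A trimmed, self-contained imitation of the shape of that change (the real modules take many more constructor arguments):

import torch
import torch.nn as nn

class ToySampler(nn.Module):
    # Stand-in for sglang.srt.layers.sampler.Sampler.
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.argmax(torch.softmax(logits, dim=-1), dim=-1)

class ToyCausalLM(nn.Module):
    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab)
        self.sampler = ToySampler()  # constructed in __init__, as in the diff

    def forward(self, x: torch.Tensor):
        hidden_states = self.backbone(x)
        logits_output = self.lm_head(hidden_states)
        sample_output = self.sampler(logits_output)
        return sample_output, logits_output  # the new two-value contract

sample_output, logits_output = ToyCausalLM()(torch.randn(2, 8))
print(sample_output.shape, logits_output.shape)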
python/sglang/srt/models/commandr.py
Same Sampler wiring, applied to CohereForCausalLM.

@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()

@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/dbrx.py
Same Sampler wiring, applied to DbrxForCausalLM.

@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
python/sglang/srt/models/deepseek.py
Same Sampler wiring, applied to DeepseekForCausalLM.

@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/deepseek_v2.py
Same Sampler wiring, applied to DeepseekV2ForCausalLM.

@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     def forward(
         self,

@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/gemma.py
Same Sampler wiring, applied to GemmaForCausalLM.

@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/gemma2.py
Same Sampler wiring, applied to Gemma2ForCausalLM.

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -396,6 +397,7 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -406,9 +408,11 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
python/sglang/srt/models/gpt_bigcode.py
Same Sampler wiring, applied to GPTBigCodeForCausalLM.

@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
python/sglang/srt/models/grok.py
Same Sampler wiring, applied to Grok1ModelForCausalLM.

@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -297,6 +298,7 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -313,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/internlm2.py
Same Sampler wiring, applied to InternLM2ForCausalLM.

@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/llama2.py
Same Sampler wiring for LlamaForCausalLM, plus the rename to LogitsProcessorOutput in the import and return annotation.

@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> LogitProcessorOutput:
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_module_name(self, name):
         stacked_params_mapping = [
python/sglang/srt/models/llama_classification.py
Rename only: this model keeps its single-output forward but uses the renamed dataclass.

@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)

-        return LogitProcessorOutput(
+        return LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
python/sglang/srt/models/minicpm.py
Same Sampler wiring, applied to MiniCPMForCausalLM.

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [