sglang / Commits / 11f3cca6

Fix select (#64)

Commit 11f3cca6 (unverified), authored Jan 20, 2024 by Lianmin Zheng, committed by GitHub on Jan 20, 2024.
Parent: ca13f3b8

Showing 2 changed files with 12 additions and 2 deletions:

  python/sglang/srt/layers/logits_processor.py       +2   -1
  python/sglang/srt/managers/router/model_runner.py  +10  -1
python/sglang/srt/layers/logits_processor.py
@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
 def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
+    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
     logprobs = torch.zeros(
         (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
     )
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
     end = torch.cumsum(len_add_1.sub_(1), dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
     end.sub_(1)
+    torch.cuda.synchronize()
     sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
     res = sum_logp / len_add_1
     return res
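For context, the indexing in the hunk above is a segmented-sum trick: given a running cumsum over the flat per-token logprobs, each request's segment sum is cumsum[end] - cumsum[start] + logprobs[start], and dividing by the request length yields the normalized logprob. A minimal CPU sketch of the same trick (function and variable names here are illustrative, not sglang's API):

import torch

def normalized_logprobs(flat_logprobs: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    # Segment i occupies flat_logprobs[offsets[i] : offsets[i] + lens[i]].
    offsets = torch.cat((torch.zeros(1, dtype=lens.dtype), torch.cumsum(lens, dim=0)[:-1]))
    cumsum = torch.cumsum(flat_logprobs, dim=0)
    ends = offsets + lens - 1  # index of each segment's last element
    # With an inclusive cumsum, cumsum[end] - cumsum[start] sums the segment
    # but excludes its first element, so that element is added back.
    seg_sum = cumsum[ends] - cumsum[offsets] + flat_logprobs[offsets]
    return seg_sum / lens  # length-normalized logprob per request

# Two requests of lengths 3 and 2 packed into one flat tensor.
lp = torch.tensor([-1.0, -2.0, -3.0, -0.5, -1.5])
print(normalized_logprobs(lp, torch.tensor([3, 2])))  # tensor([-2., -1.])

Unlike this sketch, the production code mutates len_add_1 in place with sub_(1) before the final division, and the commit inserts torch.cuda.synchronize() ahead of the gather; on CPU tensors no such synchronization applies.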
python/sglang/srt/managers/router/model_runner.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
+import logging
 import numpy as np
 import torch
@@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+logger = logging.getLogger("model_runner")
+
 # for model_mode
 global_model_mode: List[str] = []
@@ -257,6 +262,8 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
 
+        logger.info("load weight begin.")
+
         # Load weights
         linear_method = None
         with _set_default_torch_dtype(torch.float16):
@@ -267,7 +274,7 @@ class ModelRunner:
             if hf_quant_config is not None:
                 # TODO: config quantization awq etc
                 quant_config = AWQConfig.from_config(hf_quant_config)
-                print(f"quant_config: {quant_config}")
+                logger.info(f"quant_config: {quant_config}")
                 linear_method = quant_config.get_linear_method()
             model = model_class(
                 config=self.model_config.hf_config, linear_method=linear_method
@@ -280,6 +287,8 @@ class ModelRunner:
             )
         self.model = model.eval()
 
+        logger.info("load weight end.")
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
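For context on the model_runner.py change: the commit introduces a named module logger and replaces a bare print with logger.info. A minimal standalone sketch of how those calls behave once the host process configures logging (the basicConfig setup here is illustrative, not sglang's actual configuration):

import logging

# In a real server the logging configuration is done once at startup.
logging.basicConfig(level=logging.INFO, format="%(name)s %(levelname)s: %(message)s")

logger = logging.getLogger("model_runner")
logger.info("load weight begin.")  # -> model_runner INFO: load weight begin.
logger.info("load weight end.")    # -> model_runner INFO: load weight end.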