Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23993a79
Unverified
Commit
23993a79
authored
Jul 31, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 31, 2024
Browse files
[Bugfix][TPU] Do not use torch.Generator for TPUs (#6981)
parent
1d2e7fb7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
0 deletions
+6
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+6
-0
No files found.
vllm/model_executor/model_loader/weight_utils.py
View file @
23993a79
...
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -490,6 +491,11 @@ def initialize_dummy_weights(
"""
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
if
current_platform
.
is_tpu
():
# XLA device does not support torch.Generator()
param
.
uniform_
(
low
,
high
)
continue
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment