"vscode:/vscode.git/clone" did not exist on "5e77e9b1afe3eb44b2013bb923eb72eb7aa8c139"
Unverified Commit 23993a79 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix][TPU] Do not use torch.Generator for TPUs (#6981)

parent 1d2e7fb7
...@@ -22,6 +22,7 @@ from vllm.logger import init_logger ...@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig, from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config) get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -490,6 +491,11 @@ def initialize_dummy_weights( ...@@ -490,6 +491,11 @@ def initialize_dummy_weights(
""" """
for param in model.state_dict().values(): for param in model.state_dict().values():
if torch.is_floating_point(param): if torch.is_floating_point(param):
if current_platform.is_tpu():
# XLA device does not support torch.Generator()
param.uniform_(low, high)
continue
generator = torch.Generator(device=param.data.device) generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed) generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16: if torch.finfo(param.data.dtype).bits < 16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment