Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0130223b
Unverified
Commit
0130223b
authored
Feb 02, 2026
by
Vasiliy Kuznetsov
Committed by
GitHub
Feb 02, 2026
Browse files
fix memory for online fp8 quantization with streaming weight load (#31914)
Signed-off-by:
vasiliy
<
vasiliy@fb.com
>
parent
5d1aef30
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
282 additions
and
37 deletions
+282
-37
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+96
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+119
-4
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+13
-0
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+1
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+53
-32
No files found.
tests/quantization/test_fp8.py
View file @
0130223b
...
@@ -5,7 +5,10 @@
...
@@ -5,7 +5,10 @@
Run `pytest tests/quantization/test_fp8.py --forked`.
Run `pytest tests/quantization/test_fp8.py --forked`.
"""
"""
import
logging
import
pytest
import
pytest
import
regex
as
re
import
torch
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
...
@@ -195,6 +198,99 @@ def test_online_quantization(
...
@@ -195,6 +198,99 @@ def test_online_quantization(
print
(
outputs
[
0
][
1
])
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
)
def
test_online_quant_peak_mem
(
vllm_runner
,
caplog_mp_spawn
,
monkeypatch
,
)
->
None
:
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
# 1. it covers both Linear and MoE paths
# 2. it is already used by other tests in CI, so adding it here
# does not increase disk space for CI runners
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
# 1.3 GiB fp8), but could not as adding one more model makes CI
# run out of disk space.
model_name
=
"allenai/OLMoE-1B-7B-0125-Instruct"
# Force spawn to ensure caplog_mp_spawn works consistently
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
with
(
caplog_mp_spawn
(
logging
.
DEBUG
)
as
log_holder
,
vllm_runner
(
model_name
,
quantization
=
"fp8"
,
enforce_eager
=
True
,
)
as
llm
,
):
outputs
=
llm
.
generate_greedy
([
"The future of AI is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
log_text
=
log_holder
.
text
# Parse memory usage from captured logs
model_memory_gib
=
None
peak_memory_gib
=
None
for
line
in
log_text
.
splitlines
():
if
model_memory_gib
is
None
:
match
=
re
.
search
(
r
"Model loading took ([\d.]+) GiB memory"
,
line
)
if
match
:
model_memory_gib
=
float
(
match
.
group
(
1
))
if
peak_memory_gib
is
None
:
match
=
re
.
search
(
r
"Peak GPU memory after loading weights: ([\d.]+) GiB"
,
line
)
if
match
:
peak_memory_gib
=
float
(
match
.
group
(
1
))
assert
model_memory_gib
is
not
None
,
"Could not find model loading memory log"
assert
peak_memory_gib
is
not
None
,
"Could not find peak memory log"
print
(
f
"GPU memory used after loading weights:
{
model_memory_gib
}
GiB"
)
print
(
f
"Peak GPU memory usage while loading weights:
{
peak_memory_gib
}
GiB"
)
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
expected_model_memory_gib
=
6.7
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
# expected as when we load and quantize weights in a streaming fashion we
# need to have individual weights in bf16 + fp8 alive at the same time.
expected_peak_memory_gib
=
expected_model_memory_gib
*
1.4
assert
model_memory_gib
<
expected_model_memory_gib
,
(
f
"
{
model_memory_gib
=
}
higher than
{
expected_model_memory_gib
}
"
)
assert
peak_memory_gib
<
expected_peak_memory_gib
,
(
f
"
{
peak_memory_gib
=
}
higher than
{
expected_peak_memory_gib
}
"
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
)
def
test_online_quant_load_format_dummy
(
vllm_runner
,
monkeypatch
,
caplog
,
)
->
None
:
with
vllm_runner
(
"ibm-granite/granite-3.0-1b-a400m-base"
,
quantization
=
"fp8"
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
)
as
llm
:
outputs
=
llm
.
generate_greedy
([
"The future of AI is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
reason
=
"FP8 is not supported on this GPU type."
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0130223b
...
@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
)
)
from
vllm.model_executor.model_loader.weight_utils
import
initialize_single_dummy_weight
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
BlockQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
...
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
...
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
return
out
return
out
def
_copy_missing_attrs
(
old
:
torch
.
Tensor
,
new
:
torch
.
Tensor
)
->
None
:
"""Copies any attrs present in `old` but not in `new` to `new`"""
new_attrs
=
set
(
dir
(
new
))
attrs_to_set
=
{}
for
attr
in
dir
(
old
):
if
attr
not
in
new_attrs
:
attrs_to_set
[
attr
]
=
getattr
(
old
,
attr
)
set_weight_attrs
(
new
,
attrs_to_set
)
class
Fp8LinearMethod
(
LinearMethodBase
):
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
...
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
...
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
=
0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param
=
layer
.
weight
# load the current weight chunk
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
with
copy_numel_counter
:
...
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
...
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if
layer
.
_loaded_numel
==
target_loaded_numel
:
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# Prevent the usual `process_weights_after_loading` call from doing
# anything
# anything
layer
.
_already_called_process_weights_after_loading
=
True
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return
res
return
res
weight
=
ModelWeightParameter
(
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
data
=
torch
.
empty
(
output_size_per_partition
,
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
input_dim
=
1
,
input_dim
=
1
,
output_dim
=
0
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
weight_loader
=
patched_weight_loader
,
)
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
weight
.
device
==
torch
.
device
(
"meta"
):
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
layer
.
weight
.
weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_single_dummy_weight
(
layer
.
weight
)
# TODO(future): support block_quant in online quant path
# TODO(future): support block_quant in online quant path
assert
not
self
.
block_quant
assert
not
self
.
block_quant
...
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
=
0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer
.
_w13_weight_orig_id
=
id
(
layer
.
w13_weight
)
layer
.
_w2_weight_orig_id
=
id
(
layer
.
w2_weight
)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if
id
(
param
)
==
layer
.
_w13_weight_orig_id
:
param
=
layer
.
w13_weight
elif
id
(
param
)
==
layer
.
_w2_weight_orig_id
:
param
=
layer
.
w2_weight
# load the current weight chunk
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
with
copy_numel_counter
:
...
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
layer
.
_loaded_numel
==
target_loaded_numel
:
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call
# Prevent the usual `process_weights_after_loading` call
# from doing anything
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return
res
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
...
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition
,
2
*
intermediate_size_per_partition
,
hidden_size
,
hidden_size
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
requires_grad
=
False
,
...
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
num_experts
,
hidden_size
,
hidden_size
,
intermediate_size_per_partition
,
intermediate_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
# WEIGHT_SCALES
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# Allocate 2 scales for w1 and w3 respectively.
...
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
w13_weight
.
device
==
torch
.
device
(
"meta"
):
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
{
"weight_loader"
:
layer
.
w13_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
initialize_single_dummy_weight
(
layer
.
w13_weight
)
if
layer
.
w2_weight
.
device
==
torch
.
device
(
"meta"
):
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
{
"weight_loader"
:
layer
.
w2_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
initialize_single_dummy_weight
(
layer
.
w2_weight
)
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
0130223b
...
@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import (
...
@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import (
initialize_model
,
initialize_model
,
process_weights_after_loading
,
process_weights_after_loading
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.utils.mem_utils
import
format_gib
from
vllm.utils.torch_utils
import
set_default_torch_dtype
from
vllm.utils.torch_utils
import
set_default_torch_dtype
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -56,6 +58,17 @@ class BaseModelLoader(ABC):
...
@@ -56,6 +58,17 @@ class BaseModelLoader(ABC):
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
# Quantization does not happen in `load_weights` but after it
# Quantization does not happen in `load_weights` but after it
self
.
load_weights
(
model
,
model_config
)
self
.
load_weights
(
model
,
model_config
)
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if
current_platform
.
is_cuda
():
peak_memory
=
torch
.
cuda
.
max_memory_allocated
()
logger
.
debug_once
(
"Peak GPU memory after loading weights: %s GiB"
,
format_gib
(
peak_memory
),
scope
=
"local"
,
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
return
model
.
eval
()
...
...
vllm/model_executor/model_loader/dummy_loader.py
View file @
0130223b
...
@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader):
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
,
model_config
)
vllm/model_executor/model_loader/weight_utils.py
View file @
0130223b
...
@@ -1059,6 +1059,7 @@ def composed_weight_loader(
...
@@ -1059,6 +1059,7 @@ def composed_weight_loader(
def
initialize_dummy_weights
(
def
initialize_dummy_weights
(
model
:
torch
.
nn
.
Module
,
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
,
low
:
float
=
-
1e-3
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
high
:
float
=
1e-3
,
seed
:
int
=
1234
,
seed
:
int
=
1234
,
...
@@ -1075,41 +1076,61 @@ def initialize_dummy_weights(
...
@@ -1075,41 +1076,61 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
the parameter's number of elements and its data type.
"""
"""
# TODO(future PR): make the check below more generic as more online
# quant backends are added
is_fp8_py_quant
=
model_config
.
quantization
==
"fp8"
for
param
in
model
.
state_dict
().
values
():
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
if
is_fp8_py_quant
and
param
.
device
==
torch
.
device
(
"meta"
):
if
current_platform
.
is_tpu
():
# for fp8.py's online quantization, dummy weight init will happen
generator
=
torch
.
Generator
(
device
=
"cpu"
)
# in `process_weights_after_loading`.
generator
.
manual_seed
(
seed
)
# TODO(future PR): consider refactoring dummy model init to compose
# Note: The param.uniform_ function cannot be used in this
# better with online quantization
# context because it demands more TPU HBM than directly copying
continue
# from a CPU tensor.
# Note: We avoid using torch.rank_like as it doesn't currently
# support the generator argument.
param
.
copy_
(
(
high
-
low
)
*
torch
.
rand
(
param
.
shape
,
generator
=
generator
,
dtype
=
param
.
dtype
,
layout
=
param
.
layout
,
requires_grad
=
param
.
requires_grad
,
device
=
"cpu"
,
)
+
low
)
torch
.
_sync
(
param
)
continue
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
initialize_single_dummy_weight
(
param
,
low
,
high
,
seed
)
def
initialize_single_dummy_weight
(
param
:
torch
.
Tensor
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
seed
:
int
=
1234
,
)
->
None
:
if
torch
.
is_floating_point
(
param
):
if
current_platform
.
is_tpu
():
generator
=
torch
.
Generator
(
device
=
"cpu"
)
generator
.
manual_seed
(
seed
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
# Note: The param.uniform_ function cannot be used in this
# uniform_ doesn't support < 16-bit datatypes (FP8)
# context because it demands more TPU HBM than directly copying
dtype
=
param
.
data
.
dtype
# from a CPU tensor.
tmp_param
=
param
.
data
.
to
(
torch
.
float16
)
# Note: We avoid using torch.rank_like as it doesn't currently
tmp_param
=
tmp_param
.
uniform_
(
low
,
high
,
generator
=
generator
).
to
(
dtype
)
# support the generator argument.
param
.
data
.
copy_
(
tmp_param
)
param
.
copy_
(
else
:
(
high
-
low
)
param
.
uniform_
(
low
,
high
,
generator
=
generator
)
*
torch
.
rand
(
param
.
shape
,
generator
=
generator
,
dtype
=
param
.
dtype
,
layout
=
param
.
layout
,
requires_grad
=
param
.
requires_grad
,
device
=
"cpu"
,
)
+
low
)
torch
.
_sync
(
param
)
return
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype
=
param
.
data
.
dtype
tmp_param
=
param
.
data
.
to
(
torch
.
float16
)
tmp_param
=
tmp_param
.
uniform_
(
low
,
high
,
generator
=
generator
).
to
(
dtype
)
param
.
data
.
copy_
(
tmp_param
)
else
:
param
.
uniform_
(
low
,
high
,
generator
=
generator
)
def
maybe_remap_kv_scale_name
(
name
:
str
,
params_dict
:
dict
)
->
str
|
None
:
def
maybe_remap_kv_scale_name
(
name
:
str
,
params_dict
:
dict
)
->
str
|
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment