Unverified Commit 2048c4e3 authored by Jerry Zhang's avatar Jerry Zhang Committed by GitHub
Browse files

[torchao] Support quantization configs using module swap (#21982)


Signed-off-by: default avatarJerry Zhang <jerryzh168@gmail.com>
parent d1336018
...@@ -507,6 +507,10 @@ steps: ...@@ -507,6 +507,10 @@ steps:
commands: commands:
# temporary install here since we need nightly, will move to requirements/test.in # temporary install here since we need nightly, will move to requirements/test.in
# after torchao 0.12 release, and pin a working version of torchao nightly here # after torchao 0.12 release, and pin a working version of torchao nightly here
# since torchao nightly is only compatible with torch nightly currently
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
......
...@@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): ...@@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
print(output) print(output)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now")
def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
torch._dynamo.reset()
model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
"-0.14.0.dev")
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0") as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor, ...@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
from torchao.quantization import quantize_ from torchao.quantization import quantize_
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
""" """
Avoid real weight allocation for faster load, since we will Avoid real weight allocation for faster load, since we will
end up setting it to param. end up setting it to param.
""" """
with torch.device("meta"): with torch.device("meta"):
dummy_linear = torch.nn.Linear(param.shape[1], # linear can't be top level module since quantize_ is inplace
param.shape[0], # while some of our configs need to do module swap, and only non-top
bias=False) # level modules support module swap
dummy_linear = torch.nn.Sequential(
torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
dummy_linear.weight = param dummy_linear[0].weight = param
quantize_(dummy_linear, torchao_config) quantize_(dummy_linear, torchao_config)
return dummy_linear.weight return dummy_linear[0].weight
class TorchAOLinearMethod(LinearMethodBase): class TorchAOLinearMethod(LinearMethodBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment