Commit 217ee621 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev

parents f0021a4d 3f78216a
......@@ -40,7 +40,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures.
# Supported hcu architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx926;gfx928;gfx936")
#
......
......@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| :------: | :------: | :------: | :------: |
| :------: | :------: | :------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes |
......@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| BloomForCausalLM | BLOOM | Yes | No | - |
| InternLMForCausalLM | InternLM | Yes | No | - |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - |
| FalconForCausalLM | falcon | Yes | No | - |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - |
......
......@@ -184,7 +184,7 @@ if __name__ == '__main__':
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......@@ -193,7 +193,7 @@ if __name__ == '__main__':
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--profile',
......
......@@ -243,7 +243,7 @@ if __name__ == "__main__":
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......@@ -252,7 +252,7 @@ if __name__ == "__main__":
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
"--device",
......
......@@ -522,7 +522,7 @@ if __name__ == "__main__":
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......@@ -531,7 +531,7 @@ if __name__ == "__main__":
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
......
......@@ -268,7 +268,7 @@ if __name__ == '__main__':
default="auto",
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
"ROCm (hcu) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)
......
......@@ -402,7 +402,7 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an hcu which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
......
......@@ -3,7 +3,7 @@
Installation with ROCm
======================
vLLM supports AMD GPUs with ROCm 6.2.
vLLM supports hcus with ROCm 6.2.
Requirements
------------
......
......@@ -15,7 +15,7 @@ The table below shows the compatibility of various quantization implementations
- Ampere
- Ada
- Hopper
- AMD GPU
- hcu
- Intel GPU
- x86 CPU
- AWS Inferentia
......
# FP8 KV Cache
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (hcu) platforms.
## Prerequisites
......@@ -41,7 +41,7 @@ Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {
KV Scale Extraction Example
optional arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (hcu).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
......@@ -87,8 +87,8 @@ optional arguments:
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
--enforce-eager enforce eager execution
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is instead supported ```for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is instead supported for common inference criteria.
```
```
Example:
......
......@@ -325,12 +325,12 @@ if __name__ == "__main__":
"use by vLLM (pass this file to the appropriate "
"runtime typically using the argument "
"--quantization-param-path <filename>). This is only used "
"if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
"if the KV cache dtype is FP8 and on ROCm (hcu).")
parser.add_argument(
"--quantized-model",
help="Specify the directory containing a single quantized HF model. "
"It is expected that the quantization format is FP8_E4M3, for use "
"on ROCm (AMD GPU).",
"on ROCm (hcu).",
required=True)
parser.add_argument(
"--load_format",
......
......@@ -470,7 +470,7 @@ if __name__ == "__main__":
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......@@ -479,7 +479,7 @@ if __name__ == "__main__":
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
......
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
if __name__ == '__main__':
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
# Common dependencies
-r requirements-common.txt
# Dependencies for AMD GPUs
# Dependencies for hcus
awscli
boto3
botocore
......
......@@ -5,6 +5,8 @@ pytest-forked
pytest-asyncio
pytest-rerunfailures
pytest-shard
pytest-html
pytest-timeout
# testing utils
awscli
......
......@@ -399,9 +399,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
try:
__version__ = "0.6.2"
__version_tuple__ = (0, 6, 2)
__dcu_version__ = f'0.6.2+{version}'
__hcu_version__ = f'0.6.2+{version}'
from vllm.version import __version__, __version_tuple__, __dcu_version__
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
......@@ -420,7 +420,7 @@ def get_version():
version_file = 'vllm/version.py'
with open(version_file, encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__dcu_version__']
return locals()['__hcu_version__']
def get_vllm_version() -> str:
......
......@@ -8,6 +8,8 @@ from collections import UserList
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
TypedDict, TypeVar, Union)
import pytest
import pytest_html
import numpy as np
import pytest
......@@ -898,3 +900,21 @@ def dummy_opt_path():
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_path
# 定义一个 pytest 钩子,在测试后生成报告
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
# 获取测试结果
outcome = yield
result = outcome.get_result()
# 如果测试失败并且有浏览器实例,添加截图
if result.when == "call" and result.failed:
if hasattr(item, "funcargs") and "browser" in item.funcargs:
browser = item.funcargs["browser"]
screenshot_path = "screenshot.png" # 设置截图路径
browser.save_screenshot(screenshot_path)
# 如果测试结果有 extra 属性,则添加截图
if hasattr(result, "extra"):
result.extra.append(pytest_html.extras.image(screenshot_path))
......@@ -46,7 +46,8 @@ def llm():
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
# return snapshot_download(repo_id=LORA_NAME)
return LORA_NAME
@pytest.mark.skip_global_cleanup
......
......@@ -3,13 +3,13 @@ from typing import Type
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol
from .utils import torch_version
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
......@@ -49,14 +49,21 @@ def test_act_and_mul(
fn = torch.ops._C.gelu_tanh_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
if torch_version.startswith("2.3"):
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
......@@ -83,10 +90,20 @@ def test_activation(
fn = activation[1]
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
out = torch.empty_like(x)
opcheck(fn, (out, x))
if torch_version.startswith("2.3"):
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
out = torch.empty_like(x)
opcheck(fn, (out, x))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
......@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
from .utils import torch_version
if not is_hip():
from xformers import ops as xops
......@@ -186,49 +186,9 @@ def test_paged_attention(
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
if version == "v2":
ops.paged_attention_v2(
if torch_version.startswith("2.3"):
ops.paged_attention_v1(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
......@@ -243,21 +203,10 @@ def test_paged_attention(
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
ops.paged_attention_rocm(
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_v1(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
......@@ -273,13 +222,133 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
if version == "v2":
if torch_version.startswith("2.3"):
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
else:
if torch_version.startswith("2.3"):
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
else:
raise AssertionError(f"Unknown version: {version}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment