Commit 7e1d5e53 authored by zhuwenwen

merge v0.3.1

parents e3378b20 5f08050d
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "../../attention/dtype_float16.cuh" #include "../../attention/dtype_float16.cuh"
// #include "../../attention/dtype_bfloat16.cuh" // #include "../../attention/dtype_bfloat16.cuh"
#pragma once
namespace vllm { namespace vllm {
#ifdef ENABLE_FP8_E5M2 #ifdef ENABLE_FP8_E5M2
......
...@@ -94,3 +94,5 @@ class MockedClassDocumenter(autodoc.ClassDocumenter): ...@@ -94,3 +94,5 @@ class MockedClassDocumenter(autodoc.ClassDocumenter):
autodoc.ClassDocumenter = MockedClassDocumenter autodoc.ClassDocumenter = MockedClassDocumenter
navigation_with_keys = False
...@@ -12,7 +12,7 @@ Requirements ...@@ -12,7 +12,7 @@ Requirements
* OS: Linux * OS: Linux
* Python: 3.8 -- 3.11 * Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942) * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
* Pytorch 2.0.1/2.1.1/2.2 * Pytorch 2.0.1/2.1.1/2.2
* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9) * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
...@@ -105,6 +105,7 @@ The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later ...@@ -105,6 +105,7 @@ The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5` * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.
Their values can be passed in when running ``docker build`` with ``--build-arg`` options. Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
......
...@@ -67,3 +67,13 @@ You can also build and install vLLM from source: ...@@ -67,3 +67,13 @@ You can also build and install vLLM from source:
$ # Use `--ipc=host` to make sure the shared memory is large enough. $ # Use `--ipc=host` to make sure the shared memory is large enough.
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
.. note::
If you are developing the C++ backend of vLLM, consider building vLLM with
.. code-block:: console
$ python setup.py develop
since it will give you incremental builds. The downside is that this method
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
...@@ -82,12 +82,14 @@ Documentation ...@@ -82,12 +82,14 @@ Documentation
models/supported_models models/supported_models
models/adding_model models/adding_model
models/engine_args models/engine_args
models/lora
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Quantization :caption: Quantization
quantization/auto_awq quantization/auto_awq
quantization/fp8_e5m2_kv_cache
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
......
.. _lora:
Using LoRA adapters
===================
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
Adapters can be efficiently served on a per-request basis with minimal overhead. First, we download the adapter(s) and save
them locally with
.. code-block:: python
from huggingface_hub import snapshot_download
sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
Then we instantiate the base model and pass in the ``enable_lora=True`` flag:
.. code-block:: python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
We can now submit the prompts and call ``llm.generate`` with the ``lora_request`` parameter. The first parameter
of ``LoRARequest`` is a human-identifiable name, the second is a globally unique ID for the adapter, and
the third is the path to the LoRA adapter.
.. code-block:: python
sampling_params = SamplingParams(
temperature=0,
max_tokens=256,
stop=["[/assistant]"]
)
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
]
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest("sql_adapter", 1, sql_lora_path)
)
Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
\ No newline at end of file
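The same ``lora_request`` mechanism extends naturally to serving several adapters from one ``LLM`` instance. Below is a minimal sketch that reuses ``llm``, ``prompts`` and ``sampling_params`` from above; the second adapter's name and path are hypothetical placeholders.

.. code-block:: python

    from vllm.lora.request import LoRARequest

    # Each adapter gets a stable, globally unique integer ID; reusing an ID
    # reuses the adapter already registered under it.
    sql_outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("sql_adapter", 1, sql_lora_path))

    other_outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("other_adapter", 2, "/path/to/other/lora"))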
...@@ -47,9 +47,12 @@ Alongside each architecture, we include some popular models that use it. ...@@ -47,9 +47,12 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`InternLMForCausalLM` * - :code:`InternLMForCausalLM`
- InternLM - InternLM
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
* - :code:`InternLM2ForCausalLM`
- InternLM2
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
* - :code:`LlamaForCausalLM` * - :code:`LlamaForCausalLM`
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc. - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
* - :code:`MistralForCausalLM` * - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct - Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
...@@ -74,9 +77,6 @@ Alongside each architecture, we include some popular models that use it. ...@@ -74,9 +77,6 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`StableLMEpochForCausalLM` * - :code:`StableLMEpochForCausalLM`
- StableLM - StableLM
- :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
* - :code:`YiForCausalLM`
- Yi
- :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model. Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
......
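For instance, any of the checkpoints listed above can be loaded directly. A minimal sketch, using one of the listed InternLM2 checkpoints (``trust_remote_code=True`` is assumed to be required for its custom modeling code on the Hugging Face Hub):

.. code-block:: python

    from vllm import LLM

    llm = LLM(model="internlm/internlm2-7b", trust_remote_code=True)
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)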
...@@ -9,6 +9,7 @@ The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 ...@@ -9,6 +9,7 @@ The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16
Here is an example of how to enable this feature: Here is an example of how to enable this feature:
.. code-block:: python .. code-block:: python
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
......
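For reference, a complete minimal version of that example looks like the following; the model name and prompts are illustrative, and ``kv_cache_dtype="fp8_e5m2"`` is the switch this page documents.

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Sample prompts (illustrative).
    prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Enable the FP8 E5M2 KV cache.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")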
...@@ -9,13 +9,13 @@ To install langchain, run ...@@ -9,13 +9,13 @@ To install langchain, run
.. code-block:: console .. code-block:: console
$ pip install langchain -q $ pip install langchain langchain_community -q
To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langchain``. To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langchain``.
.. code-block:: python .. code-block:: python
from langchain.llms import VLLM from langchain_community.llms import VLLM
llm = VLLM(model="mosaicml/mpt-7b", llm = VLLM(model="mosaicml/mpt-7b",
trust_remote_code=True, # mandatory for hf models trust_remote_code=True, # mandatory for hf models
...@@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha ...@@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha
print(llm("What is the capital of France ?")) print(llm("What is the capital of France ?"))
Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/llms/vllm.ipynb>`_ for more details. Please refer to this `Tutorial <https://python.langchain.com/docs/integrations/llms/vllm>`_ for more details.
\ No newline at end of file
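Putting the pieces above together, an end-to-end sketch might look like the following; the sampling parameters are illustrative.

.. code-block:: python

    from langchain_community.llms import VLLM

    llm = VLLM(
        model="mosaicml/mpt-7b",
        trust_remote_code=True,  # mandatory for Hugging Face models
        max_new_tokens=128,
        top_k=10,
        top_p=0.95,
        temperature=0.8,
    )

    print(llm("What is the capital of France ?"))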
"""
This example shows how to use Ray Data to run offline batch inference
in a distributed fashion on a multi-node cluster.
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""
from vllm import LLM, SamplingParams
from typing import Dict
import numpy as np
import ray
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create a class to do batch inference.
class LLMPredictor:
def __init__(self):
# Create an LLM.
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
# Generate texts from the prompts.
# The output is a list of RequestOutput objects that contain the prompt,
# generated text, and other information.
outputs = self.llm.generate(batch["text"], sampling_params)
prompt = []
generated_text = []
for output in outputs:
prompt.append(output.prompt)
generated_text.append(' '.join([o.text for o in output.outputs]))
return {
"prompt": prompt,
"generated_text": generated_text,
}
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, or binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
# Apply batch inference for all input data.
ds = ds.map_batches(
LLMPredictor,
# Set the concurrency to the number of LLM instances.
concurrency=10,
# Specify the number of GPUs required per LLM instance.
# NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
# (i.e., `tensor_parallel_size`).
num_gpus=1,
# Specify the batch size for inference.
batch_size=32,
)
# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# you should write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
prompt = output["prompt"]
generated_text = output["generated_text"]
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
# vLLM + Prometheus/Grafana
This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can check out other installation methods on the [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites.
Install:
- [`docker`](https://docs.docker.com/engine/install/)
- [`docker compose`](https://docs.docker.com/compose/install/linux/#install-using-the-repository)
### Launch
Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint:
```bash
python3 -m vllm.entrypoints.openai.api_server \
--model mistralai/Mistral-7B-v0.1 \
--max-model-len 2048 \
--disable-log-requests
```
Launch Prometheus and Grafana servers with `docker compose`:
```bash
docker compose up
```
Submit some sample requests to the server:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
python3 ../../benchmarks/benchmark_serving.py \
--model mistralai/Mistral-7B-v0.1 \
--tokenizer mistralai/Mistral-7B-v0.1 \
--endpoint /v1/completions \
--dataset ShareGPT_V3_unfiltered_cleaned_split.json \
--request-rate 3.0
```
Navigating to [`http://localhost:8000/metrics`](http://localhost:8000/metrics) will show the raw Prometheus metrics being exposed by vLLM.
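If you prefer to inspect the metrics programmatically, a small sketch like the one below works as well. It only assumes the server launched above is still running; vLLM-specific series are prefixed with `vllm:`.

```python
import urllib.request

# Fetch the raw Prometheus exposition text from the running vLLM server.
with urllib.request.urlopen("http://localhost:8000/metrics") as resp:
    metrics_text = resp.read().decode("utf-8")

# Print only the vLLM-specific series.
for line in metrics_text.splitlines():
    if line.startswith("vllm:"):
        print(line)
```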
### Grafana Dashboard
Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the default username (`admin`) and password (`admin`).
#### Add Prometheus Data Source
Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus.
On the Prometheus configuration page, we need to add the `Prometheus Server URL` under `Connection`. For this setup, Grafana and Prometheus run in separate containers, but Docker creates a DNS name for each container, so you can simply use `http://prometheus:9090`.
Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API."
#### Import Dashboard
Navigate to [`http://localhost:3000/dashboard/import`](http://localhost:3000/dashboard/import), upload `grafana.json`, and select the `prometheus` datasource. You should see a screen that looks like the following:
![Grafana Dashboard Image](https://i.imgur.com/R2vH9VW.png)
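Once everything is up, you can also sanity-check that Prometheus is scraping vLLM by querying its HTTP API directly. This is a sketch: `vllm:num_requests_running` is one example of a series vLLM exports and may differ by version, so adjust it to whatever you see on the `/metrics` page.

```python
import json
import urllib.parse
import urllib.request

# Example vLLM series; replace with any metric name shown on /metrics.
query = urllib.parse.quote("vllm:num_requests_running")
url = f"http://localhost:9090/api/v1/query?query={query}"

with urllib.request.urlopen(url) as resp:
    result = json.load(resp)

print(json.dumps(result, indent=2))
```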
# docker-compose.yaml
version: "3"
services:
prometheus:
image: prom/prometheus:latest
extra_hosts:
- "host.docker.internal:host-gateway" # allow a direct connection from container to the local machine
ports:
- "9090:9090" # the default port used by Prometheus
volumes:
- ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml # mount Prometheus config file
grafana:
image: grafana/grafana:latest
depends_on:
- prometheus
ports:
- "3000:3000" # the default port used by Grafana
# prometheus.yaml
global:
scrape_interval: 5s
evaluation_interval: 30s
scrape_configs:
- job_name: vllm
static_configs:
- targets:
- 'host.docker.internal:8000'
...@@ -11,3 +11,5 @@ uvicorn[standard] ...@@ -11,3 +11,5 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server. pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette] aioprometheus[starlette]
pynvml == 11.5.0 pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
--- amd_hip_bf16.h 2024-02-06 18:28:58.268699142 +0000
+++ amd_hip_bf16.h.new 2024-02-06 18:28:31.988647133 +0000
@@ -90,10 +90,10 @@
#include "math_fwd.h" // ocml device functions
#if defined(__HIPCC_RTC__)
-#define __HOST_DEVICE__ __device__
+#define __HOST_DEVICE__ __device__ static
#else
#include <climits>
-#define __HOST_DEVICE__ __host__ __device__
+#define __HOST_DEVICE__ __host__ __device__ static inline
#endif
// Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
...@@ -19,11 +19,16 @@ from pathlib import Path ...@@ -19,11 +19,16 @@ from pathlib import Path
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# The downside is that this method is deprecated, see
# https://github.com/pypa/setuptools/issues/917
MAIN_CUDA_VERSION = "12.1" MAIN_CUDA_VERSION = "12.1"
# Supported NVIDIA GPU architectures. # Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx926","gfx1030", "gfx1100"} ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx926","gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
...@@ -67,22 +72,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] ...@@ -67,22 +72,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
def get_amdgpu_offload_arch():
command = "/opt/dtk-23.10/llvm/bin/amdgpu-offload-arch"
try:
output = subprocess.check_output([command])
return output.decode('utf-8').strip()
except subprocess.CalledProcessError as e:
error_message = f"Error: {e}"
raise RuntimeError(error_message) from e
except FileNotFoundError as e:
# If the command is not found, print an error message
error_message = f"The command {command} was not found."
raise RuntimeError(error_message) from e
return None
def get_hipcc_rocm_version(): def get_hipcc_rocm_version():
# Run the hipcc --version command # Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'], result = subprocess.run(['hipcc', '--version'],
...@@ -142,6 +131,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ...@@ -142,6 +131,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version return nvcc_cuda_version
def get_pytorch_rocm_arch() -> Set[str]:
"""Get the cross section of Pytorch,and vllm supported gfx arches
ROCM can get the supported gfx architectures in one of two ways
Either through the PYTORCH_ROCM_ARCH env var, or output from
rocm_agent_enumerator.
In either case we can generate a list of supported arch's and
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
"""
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
# If we don't have PYTORCH_ROCM_ARCH specified, pull the list from rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
.strip().replace("\n", ";")
arch_source_str = "rocm_agent_enumerator"
else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable"
# Lists are separated by ';' or spaces.
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
# Filter out the invalid architectures and print a warning.
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
f"None of the ROCM architectures in {arch_source_str} "
f"({env_arch_list}) is supported. "
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
if invalid_arch_list:
warnings.warn(
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
f"excluded from the {arch_source_str} output "
f"({env_arch_list}). Supported ROCM architectures are: "
f"{ROCM_SUPPORTED_ARCHS}.",
stacklevel=2)
return arch_list
def get_torch_arch_list() -> Set[str]: def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures, # TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
...@@ -166,22 +199,27 @@ def get_torch_arch_list() -> Set[str]: ...@@ -166,22 +199,27 @@ def get_torch_arch_list() -> Set[str]:
# If none of the specified architectures are valid, raise an error. # If none of the specified architectures are valid, raise an error.
if not arch_list: if not arch_list:
raise RuntimeError( raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env " "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. " f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.") f"Supported CUDA architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list: if invalid_arch_list:
warnings.warn( warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are " f"Unsupported CUDA architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable " "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: " f"({env_arch_list}). Supported CUDA architectures are: "
f"{valid_archs}.", f"{valid_archs}.",
stacklevel=2) stacklevel=2)
return arch_list return arch_list
# First, check the TORCH_CUDA_ARCH_LIST environment variable. if _is_hip():
compute_capabilities = get_torch_arch_list() rocm_arches = get_pytorch_rocm_arch()
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
else:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities: if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine. # GPUs on the current machine.
...@@ -290,18 +328,6 @@ if _is_cuda(): ...@@ -290,18 +328,6 @@ if _is_cuda():
"nvcc": NVCC_FLAGS_PUNICA, "nvcc": NVCC_FLAGS_PUNICA,
}, },
)) ))
elif _is_hip():
amd_archs = os.getenv("GPU_ARCHS")
if amd_archs is None:
# amd_archs = get_amdgpu_offload_arch()
amd_archs = "gfx906;gfx926"
for arch in amd_archs.split(";"):
if arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {arch}")
NVCC_FLAGS += [f"--offload-arch={arch}"]
elif _is_neuron(): elif _is_neuron():
neuronxcc_version = get_neuronxcc_version() neuronxcc_version = get_neuronxcc_version()
...@@ -322,6 +348,17 @@ if _is_cuda(): ...@@ -322,6 +348,17 @@ if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
vllm_extension_sources.append("csrc/custom_all_reduce.cu") vllm_extension_sources.append("csrc/custom_all_reduce.cu")
# Add MoE kernels.
ext_modules.append(
CUDAExtension(
name="vllm._moe_C",
sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
))
if not _is_neuron(): if not _is_neuron():
vllm_extension = CUDAExtension( vllm_extension = CUDAExtension(
name="vllm._C", name="vllm._C",
...@@ -393,8 +430,8 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -393,8 +430,8 @@ def get_version_add(sha: Optional[str] = None) -> str:
version += ".torch" + torch.__version__[:3] version += ".torch" + torch.__version__[:3]
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write("__version__='0.3.0'\n") file.write("__version__='0.3.1'\n")
file.write("__dcu_version__='0.3.0+{}'\n".format(version)) file.write("__dcu_version__='0.3.1+{}'\n".format(version))
file.close() file.close()
...@@ -411,9 +448,9 @@ def get_vllm_version() -> str: ...@@ -411,9 +448,9 @@ def get_vllm_version() -> str:
if _is_hip(): if _is_hip():
# Get the HIP version # Get the HIP version
# hipcc_version = get_hipcc_rocm_version() hipcc_version = get_hipcc_rocm_version()
# if hipcc_version != MAIN_CUDA_VERSION: if hipcc_version != MAIN_CUDA_VERSION:
# rocm_version_str = hipcc_version.replace(".", "")[:3] rocm_version_str = hipcc_version.replace(".", "")[:3]
# version += f"+rocm{rocm_version_str}" # version += f"+rocm{rocm_version_str}"
version = get_version() version = get_version()
elif _is_neuron(): elif _is_neuron():
......
import torch
# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
torch.float16: 1e-3,
torch.bfloat16: 1.6e-2,
torch.float: 1.3e-6
}
def get_default_atol(output) -> float:
return default_atol[output.dtype]
def get_default_rtol(output) -> float:
return default_rtol[output.dtype]
...@@ -2,77 +2,92 @@ import pytest ...@@ -2,77 +2,92 @@ import pytest
import torch import torch
from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_silu_and_mul( def test_silu_and_mul(
num_tokens: int, num_tokens: int,
d: int, d: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: int, device: str,
) -> None: ) -> None:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) if torch.cuda.is_available():
gpu_id = f"cuda:{device}" torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = SiluAndMul() layer = SiluAndMul()
out = layer(x) out = layer(x)
ref_out = layer._forward(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_gelu_new( def test_gelu_new(
num_tokens: int, num_tokens: int,
d: int, d: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: int, device: str,
) -> None: ) -> None:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) if torch.cuda.is_available():
gpu_id = f"cuda:{device}" torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = NewGELU() layer = NewGELU()
out = layer(x) out = layer(x)
ref_out = layer._forward(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gelu_fast( def test_gelu_fast(
num_tokens: int, num_tokens: int,
d: int, d: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: int, device: str,
) -> None: ) -> None:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) if torch.cuda.is_available():
gpu_id = f"cuda:{device}" torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = FastGELU() layer = FastGELU()
out = layer(x) out = layer(x)
ref_out = layer._forward(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
...@@ -8,6 +8,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask ...@@ -8,6 +8,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import ops, cache_ops from vllm._C import ops, cache_ops
from vllm.utils import get_max_shared_memory_bytes from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import is_hip
from allclose_default import get_default_atol, get_default_rtol
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
...@@ -17,17 +19,25 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 ...@@ -17,17 +19,25 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens. # Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
] if not is_hip() else [64, 80, 96, 112, 128]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0] SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def ref_masked_attention( def ref_masked_attention(
...@@ -91,7 +101,7 @@ def ref_single_query_cached_kv_attention( ...@@ -91,7 +101,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None alibi_bias = None
if alibi_slopes is not None: if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel. # Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device=query.device).int() position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float() alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1) 1, 1, -1)
...@@ -110,7 +120,7 @@ def ref_single_query_cached_kv_attention( ...@@ -110,7 +120,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention( def test_paged_attention(
kv_cache_factory, kv_cache_factory,
version: str, version: str,
...@@ -122,33 +132,28 @@ def test_paged_attention( ...@@ -122,33 +132,28 @@ def test_paged_attention(
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str, kv_cache_dtype: str,
seed: int, seed: int,
device: int, device: str,
) -> None: ) -> None:
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) if torch.cuda.is_available():
gpu_id = f"cuda:{device}" torch.cuda.manual_seed(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
num_query_heads,
head_size,
dtype=dtype,
device=gpu_id)
query.uniform_(-scale, scale) query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None alibi_slopes = None
if use_alibi: if use_alibi:
alibi_slopes = torch.randn(num_query_heads, alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
dtype=torch.float,
device=gpu_id)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens) max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) context_lens = torch.tensor(context_lens, dtype=torch.int)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
...@@ -159,13 +164,13 @@ def test_paged_attention( ...@@ -159,13 +164,13 @@ def test_paged_attention(
for _ in range(max_num_blocks_per_seq) for _ in range(max_num_blocks_per_seq)
] ]
block_tables.append(block_table) block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, num_kv_heads, head_size,
kv_cache_dtype, dtype, seed, kv_cache_dtype, dtype, seed,
gpu_id) device)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel. # Call the paged attention kernel.
...@@ -193,12 +198,10 @@ def test_paged_attention( ...@@ -193,12 +198,10 @@ def test_paged_attention(
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size), size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
device=output.device,
) )
exp_sums = torch.empty( exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions), size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32, dtype=torch.float32,
device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2( ops.paged_attention_v2(
...@@ -229,14 +232,14 @@ def test_paged_attention( ...@@ -229,14 +232,14 @@ def test_paged_attention(
block_size, x) block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape, dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype, dtype=dtype,
device=gpu_id) device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape, dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype, dtype=dtype,
device=gpu_id) device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache value_cache = dequantized_value_cache
...@@ -256,9 +259,11 @@ def test_paged_attention( ...@@ -256,9 +259,11 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two # NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two # implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test. # outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8_e5m2": if kv_cache_dtype == "fp8_e5m2":
atol, rtol = 1e-2, 1e-5 atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
...@@ -283,7 +288,7 @@ def ref_multi_query_kv_attention( ...@@ -283,7 +288,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1) diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device=query.device) attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention( ref_output = ref_masked_attention(
query[start_idx:end_idx], query[start_idx:end_idx],
...@@ -303,7 +308,7 @@ def ref_multi_query_kv_attention( ...@@ -303,7 +308,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_multi_query_kv_attention( def test_multi_query_kv_attention(
num_seqs: int, num_seqs: int,
...@@ -311,12 +316,13 @@ def test_multi_query_kv_attention( ...@@ -311,12 +316,13 @@ def test_multi_query_kv_attention(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: int, device: str,
) -> None: ) -> None:
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) if torch.cuda.is_available():
gpu_id = f"cuda:{device}" torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use # As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here. # a smaller MAX_SEQ_LEN here.
...@@ -329,8 +335,7 @@ def test_multi_query_kv_attention( ...@@ -329,8 +335,7 @@ def test_multi_query_kv_attention(
qkv = torch.empty(num_tokens, qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads, num_query_heads + 2 * num_kv_heads,
head_size, head_size,
dtype=dtype, dtype=dtype)
device=gpu_id)
qkv.uniform_(-scale, scale) qkv.uniform_(-scale, scale)
query, key, value = qkv.split( query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1) [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
...@@ -362,4 +367,6 @@ def test_multi_query_kv_attention( ...@@ -362,4 +367,6 @@ def test_multi_query_kv_attention(
scale, scale,
dtype, dtype,
) )
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)