Unverified Commit 7ecee343 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Kernel][RFC] Refactor the punica kernel based on Triton (#5036)

parent 7eb0cb4a
......@@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt
# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1
# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build
......
......@@ -223,61 +223,7 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
#
# _punica_C extension
#
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/torch_bindings.cpp")
#
# Copy GPU compilation flags+update for punica
#
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
#
# Filter out CUDA architectures < 8.0 for punica.
#
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
set(VLLM_PUNICA_GPU_ARCHES)
foreach(ARCH ${VLLM_GPU_ARCHES})
string_to_ver(CODE_VER ${ARCH})
if (CODE_VER GREATER_EQUAL 8.0)
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
define_gpu_extension_target(
_punica_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
endif()
#
# Add the `default` target which detects which extensions should be
......@@ -301,12 +247,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
# there are supported target arches.
if (VLLM_PUNICA_GPU_ARCHES AND
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
message(STATUS "Enabling punica extension.")
add_dependencies(default _punica_C)
endif()
endif()
......@@ -88,8 +88,6 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
ARG buildkite_commit
ENV BUILDKITE_COMMIT=${buildkite_commit}
......
......@@ -131,8 +131,7 @@ COPY . .
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
......
Contains code from https://github.com/punica-ai/punica
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer
BSD-3-Clause:
* third_party/cutlass
\ No newline at end of file
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 896) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1216) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1664) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2240) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2368) \
f(in_T, out_T, W_T, narrow, 2432) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 3712) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4480) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 4736) \
f(in_T, out_T, W_T, narrow, 4864) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 5888) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 7424) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 8960) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 9472) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 11264) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 14784) \
f(in_T, out_T, W_T, narrow, 14848) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 18944) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 22528) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 29568) \
f(in_T, out_T, W_T, narrow, 29696) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 49408) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 896, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1216, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1664, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2240, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2368, narrow) \
f(in_T, out_T, W_T, 2432, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 3712, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4480, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 4736, narrow) \
f(in_T, out_T, W_T, 4864, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 5888, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 7424, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 8960, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 9472, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 11264, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 14784, narrow) \
f(in_T, out_T, W_T, 14848, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 18944, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 22528, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 29568, narrow) \
f(in_T, out_T, W_T, 29696, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 49408, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
#pragma once
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
// Does not handle the case of long datatypes
char *d = reinterpret_cast<char *>(dst);
const char *s = reinterpret_cast<const char *>(src);
size_t i = 0;
#pragma unroll
for (i = 0; i < len; ++i) {
d[i] = s[i];
}
return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
constexpr size_t tile_size = tx * ty * vec_size;
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
__syncthreads();
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "nv_bfloat16",
"fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
if output_dtype == "fp32":
# LoRA A matrix.
if input_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif input_dtype == "fp32":
# LoRA B matrix.
if output_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif not (input_dtype == output_dtype == weight_dtype):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>
#include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__
template <typename float_t, size_t vec_size>
struct vec_t {
FLASHINFER_INLINE float_t &operator[](size_t i);
FLASHINFER_INLINE const float_t &operator[](size_t i) const;
FLASHINFER_INLINE void fill(float_t val);
FLASHINFER_INLINE void load(const float_t *ptr);
FLASHINFER_INLINE void store(float_t *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src);
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr);
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const;
FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src);
};
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<src_float_t, vec_size> &src,
vec_t<tgt_float_t, vec_size> &dst) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
dst[i] = tgt_float_t(src[i]);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr,
vec_t<tgt_float_t, vec_size> &dst) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
dst.load(src_ptr);
} else {
vec_t<src_float_t, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_store_impl(const vec_t<src_float_t, vec_size> &src,
tgt_float_t *dst_ptr) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
src.store(dst_ptr);
} else {
vec_t<tgt_float_t, vec_size> tmp;
tmp.cast_from(src);
tmp.store(dst_ptr);
}
}
#ifdef FLASHINFER_USE_FP8
/******************* vec_t<__nv_fp8_e4m3> *******************/
// __nv_fp8_e4m3 x 1
template <>
struct vec_t<__nv_fp8_e4m3, 1> {
__nv_fp8_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(
__nv_fp8_e4m3 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*dst = *src;
}
// __nv_fp8_e4m3 x 2
template <>
struct vec_t<__nv_fp8_e4m3, 2> {
__nv_fp8x2_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x2_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x2_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src);
}
// __nv_fp8_e4m3 x 4
template <>
struct vec_t<__nv_fp8_e4m3, 4> {
__nv_fp8x4_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x4_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x4_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src);
}
// __nv_fp8_e4m3 x 8
template <>
struct vec_t<__nv_fp8_e4m3, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(
__nv_fp8_e4m3 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src);
}
// __nv_fp8_e4m3 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e4m3, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<__nv_fp8_e5m2> *******************/
// __nv_fp8_e5m2 x 1
template <>
struct vec_t<__nv_fp8_e5m2, 1> {
__nv_fp8_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(
__nv_fp8_e5m2 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*dst = *src;
}
// __nv_fp8_e5m2 x 2
template <>
struct vec_t<__nv_fp8_e5m2, 2> {
__nv_fp8x2_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x2_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x2_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src);
}
// __nv_fp8_e5m2 x 4
template <>
struct vec_t<__nv_fp8_e5m2, 4> {
__nv_fp8x4_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x4_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x4_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src);
}
// __nv_fp8_e5m2 x 8
template <>
struct vec_t<__nv_fp8_e5m2, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(
__nv_fp8_e5m2 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src);
}
// __nv_fp8_e5m2 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e5m2, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
#endif
/******************* vec_t<half> *******************/
// half x 1
template <>
struct vec_t<half, 1> {
half data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
FLASHINFER_INLINE void vec_t<half, 1>::load(const half *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<half, 1>::store(half *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half *dst, const half *src) {
*dst = *src;
}
// half x 2
template <>
struct vec_t<half, 2> {
half2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) {
data = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 2>::load(const half *ptr) {
data = *((half2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 2>::store(half *ptr) const {
*((half2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half *dst, const half *src) {
*((half2 *)dst) = *((half2 *)src);
}
// half x 4
template <>
struct vec_t<half, 4> {
uint2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) {
*(half2 *)(&data.x) = make_half2(val, val);
*(half2 *)(&data.y) = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 4>::load(const half *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 4>::store(half *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half *dst, const half *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)data)[i];
}
FLASHINFER_INLINE void fill(half val) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
*(half2 *)(&(data[i].x)) = make_half2(val, val);
*(half2 *)(&(data[i].y)) = make_half2(val, val);
*(half2 *)(&(data[i].z)) = make_half2(val, val);
*(half2 *)(&(data[i].w)) = make_half2(val, val);
}
}
FLASHINFER_INLINE void load(const half *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(half *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<nv_bfloat16> *******************/
// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
nv_bfloat16 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*dst = *src;
}
// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
nv_bfloat162 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
data = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16 *ptr) {
data = *((nv_bfloat162 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16 *ptr) const {
*((nv_bfloat162 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src);
}
// nv_bfloat16 x 4
template <>
struct vec_t<nv_bfloat16, 4> {
uint2 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
*(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val) {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
*(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val);
}
}
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
#pragma unoll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<float> *******************/
// float x 1
template <>
struct vec_t<float, 1> {
float data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
FLASHINFER_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
*dst = *src;
}
// float x 2
template <>
struct vec_t<float, 2> {
float2 data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) {
data = make_float2(val, val);
}
FLASHINFER_INLINE void vec_t<float, 2>::load(const float *ptr) {
data = *((float2 *)ptr);
}
FLASHINFER_INLINE void vec_t<float, 2>::store(float *ptr) const {
*((float2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(data))[i];
}
FLASHINFER_INLINE void fill(float val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = make_float4(val, val, val, val);
}
}
FLASHINFER_INLINE void load(const float *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(float *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)dst)[i] = ((float4 *)src)[i];
}
}
};
/******************* vec_t type cast *******************/
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = half(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<nv_bfloat16, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] =
__bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<nv_bfloat16, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = nv_bfloat16(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((nv_bfloat162 *)(&dst.data))[i] =
__float22bfloat162_rn(((float2 *)(&src.data))[i]);
}
}
}
#ifdef FLASHINFER_USE_FP8
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e4m3 *)(&dst.data))[i] =
__nv_fp8x4_e4m3(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e5m2(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e5m2 *)(&dst.data))[i] =
__nv_fp8x4_e5m2(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
#endif // FLASHINFER_USE_FP8
#endif // VEC_DTYPES_CUH_
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint64_t(a) << 32) | uint64_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
// NOTE(woosuk): While Punica supports various combinations of input/output
// data types, we limit the supported data types to reduce the binary size.
constexpr bool is_input_float = std::is_same<in_T, float>::value;
constexpr bool is_output_float = std::is_same<out_T, float>::value;
if (is_input_float) {
if (!std::is_same<out_T, W_T>::value) {
return false;
}
} else if (is_output_float) {
if (!std::is_same<in_T, W_T>::value) {
return false;
}
} else if (!(std::is_same<in_T, W_T>::value &&
std::is_same<out_T, W_T>::value)) {
return false;
}
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
#pragma once
#include <torch/all.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset);
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
m.def(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()");
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
return __hip_bfloat162{val, val};
}
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
return __hip_bfloat162{vall, valr};
}
template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
return static_cast<T_dst>(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
return __half2float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
return __float2half(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
return __float2bfloat16(val);
}
template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
return a + b;
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
return __hadd(a, b);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
return __hadd(a, b);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
......@@ -66,7 +66,6 @@ You can also build and install vLLM from source:
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability
$ pip install -e . # This may take 5-10 minutes.
.. tip::
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment