Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5467ac31
Unverified
Commit
5467ac31
authored
Jun 09, 2024
by
bnellnm
Committed by
GitHub
Jun 09, 2024
Browse files
[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)
parent
5d7e3d01
Changes
55
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
524 additions
and
115 deletions
+524
-115
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+1
-1
csrc/quantization/gptq_marlin/gptq_marlin.cuh
csrc/quantization/gptq_marlin/gptq_marlin.cuh
+1
-1
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+1
-1
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+1
-1
csrc/quantization/squeezellm/quant_cuda_kernel.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
+0
-1
csrc/registration.h
csrc/registration.h
+22
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+283
-0
setup.py
setup.py
+1
-1
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+4
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+164
-53
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+5
-5
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+19
-15
vllm/lora/punica.py
vllm/lora/punica.py
+19
-26
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-2
vllm/utils.py
vllm/utils.py
+2
-5
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
5467ac31
...
...
@@ -1867,4 +1867,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return
c
;
}
#endif
\ No newline at end of file
#endif
csrc/quantization/gptq_marlin/gptq_marlin.cuh
View file @
5467ac31
#pragma once
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
...
...
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
5467ac31
...
...
@@ -15,7 +15,7 @@
* limitations under the License.
*/
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
...
...
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
5467ac31
...
...
@@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
...
...
csrc/quantization/squeezellm/quant_cuda_kernel.cu
View file @
5467ac31
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
...
...
csrc/registration.h
0 → 100644
View file @
5467ac31
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
csrc/torch_bindings.cpp
0 → 100644
View file @
5467ac31
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops
.
def
(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
// PagedAttention V2.
ops
.
def
(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Activation ops
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"silu_and_mul"
,
torch
::
kCUDA
,
&
silu_and_mul
);
// Activation function used in GeGLU with `none` approximation.
ops
.
def
(
"gelu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_and_mul"
,
torch
::
kCUDA
,
&
gelu_and_mul
);
// Activation function used in GeGLU with `tanh` approximation.
ops
.
def
(
"gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_tanh_and_mul"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul
);
// GELU implementation used in GPT-2.
ops
.
def
(
"gelu_new(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_new"
,
torch
::
kCUDA
,
&
gelu_new
);
// Approximate GELU implementation.
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm"
,
torch
::
kCUDA
,
&
rms_norm
);
// In-place fused Add and RMS Normalization.
ops
.
def
(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops
.
def
(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
ops
.
def
(
"aqlm_gemm"
,
&
aqlm_gemm
);
ops
.
impl
(
"aqlm_gemm"
,
torch
::
kCUDA
,
&
aqlm_gemm
);
// Decompression method for AQLM.
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
);
ops
.
impl
(
"aqlm_dequant"
,
torch
::
kCUDA
,
&
aqlm_dequant
);
// Quantized GEMM for AWQ.
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
);
ops
.
impl
(
"awq_gemm"
,
torch
::
kCUDA
,
&
awq_gemm
);
// Dequantization for AWQ.
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
);
ops
.
impl
(
"awq_dequantize"
,
torch
::
kCUDA
,
&
awq_dequantize
);
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
);
ops
.
impl
(
"marlin_gemm"
,
torch
::
kCUDA
,
&
marlin_gemm
);
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
);
ops
.
impl
(
"gptq_marlin_24_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_24_gemm
);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
ops
.
impl
(
"gptq_marlin_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_gemm
);
// gptq_marlin repack from GPTQ.
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
"cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_dq"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm_dq
);
#endif
// Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
);
ops
.
impl
(
"gptq_gemm"
,
torch
::
kCUDA
,
&
gptq_gemm
);
// Post processing for GPTQ.
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"lookup_table) -> ()"
);
ops
.
impl
(
"squeezellm_gemm"
,
torch
::
kCUDA
,
&
squeezellm_gemm
);
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"
);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// Compute FP8 quantized tensor and scaling factor.
ops
.
def
(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()"
);
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops
.
def
(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
ops
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()"
);
ops
.
impl
(
"static_scaled_int8_quant"
,
torch
::
kCUDA
,
&
static_scaled_int8_quant
);
// Compute int8 quantized tensor and scaling factor
ops
.
def
(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_int8_quant
);
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops
.
def
(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"
);
cache_ops
.
impl
(
"swap_blocks"
,
torch
::
kCUDA
,
&
swap_blocks
);
// Copy the cache blocks from src to dst.
cache_ops
.
def
(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()"
);
cache_ops
.
impl
(
"copy_blocks"
,
torch
::
kCUDA
,
&
copy_blocks
);
// Reshape the key and value tensors and cache them.
cache_ops
.
def
(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCUDA
,
&
reshape_and_cache
);
// Reshape the key and value tensors and cache them.
cache_ops
.
def
(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache_flash"
,
torch
::
kCUDA
,
&
reshape_and_cache_flash
);
// Convert the key and value cache to fp8 data type.
cache_ops
.
def
(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
"kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"convert_fp8"
,
torch
::
kCUDA
,
&
convert_fp8
);
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cuda_utils
),
cuda_utils
)
{
// Cuda utils
// Gets the specified device attribute.
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
);
cuda_utils
.
impl
(
"get_device_attribute"
,
torch
::
kCUDA
,
&
get_device_attribute
);
// Gets the maximum shared memory per block device attribute.
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
);
cuda_utils
.
impl
(
"get_max_shared_memory_per_block_device_attribute"
,
torch
::
kCUDA
,
&
get_max_shared_memory_per_block_device_attribute
);
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_custom_ar
),
custom_ar
)
{
// Custom all-reduce kernels
custom_ar
.
def
(
"init_custom_ar"
,
&
init_custom_ar
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
def
(
"should_custom_ar"
,
&
should_custom_ar
);
custom_ar
.
impl
(
"should_custom_ar"
,
torch
::
kCUDA
,
&
should_custom_ar
);
custom_ar
.
def
(
"all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"
);
custom_ar
.
impl
(
"all_reduce_reg"
,
torch
::
kCUDA
,
&
all_reduce_reg
);
custom_ar
.
def
(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()"
);
custom_ar
.
impl
(
"all_reduce_unreg"
,
torch
::
kCUDA
,
&
all_reduce_unreg
);
custom_ar
.
def
(
"dispose"
,
&
dispose
);
custom_ar
.
impl
(
"dispose"
,
torch
::
kCPU
,
&
dispose
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
);
custom_ar
.
impl
(
"meta_size"
,
torch
::
kCPU
,
&
meta_size
);
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
);
custom_ar
.
impl
(
"register_buffer"
,
torch
::
kCUDA
,
&
register_buffer
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
impl
(
"get_graph_buffer_ipc_meta"
,
torch
::
kCPU
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
);
custom_ar
.
impl
(
"register_graph_buffers"
,
torch
::
kCPU
,
&
register_graph_buffers
);
}
#endif
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
setup.py
View file @
5467ac31
...
...
@@ -60,7 +60,7 @@ def remove_prefix(text, prefix):
class
CMakeExtension
(
Extension
):
def
__init__
(
self
,
name
:
str
,
cmake_lists_dir
:
str
=
'.'
,
**
kwa
)
->
None
:
super
().
__init__
(
name
,
sources
=
[],
**
kwa
)
super
().
__init__
(
name
,
sources
=
[],
py_limited_api
=
True
,
**
kwa
)
self
.
cmake_lists_dir
=
os
.
path
.
abspath
(
cmake_lists_dir
)
...
...
tests/kernels/test_int8_quant.py
View file @
5467ac31
import
pytest
import
torch
from
vllm._C
import
ops
# ruff: noqa: F401
import
vllm._C
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
...
@@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
,
device
=
"cuda"
)
scales_out
=
torch
.
empty_like
(
scales
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
ops
.
dynamic_scaled_int8_quant
(
ops_out
,
x
,
scales_out
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
ops_out
,
x
,
scales_out
)
assert
torch
.
allclose
(
scales_out
,
scales
)
assert
torch
.
allclose
(
torch_out
,
ops_out
,
...
...
@@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
out2
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
)
scale_argument
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
ops
.
static_scaled_int8_quant
(
out2
,
x
,
scale_argument
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_argument
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1
)
# big atol to account for rounding errors
vllm/_custom_ops.py
View file @
5467ac31
from
typing
import
Optional
,
Tuple
,
Type
import
contextlib
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
try
:
from
vllm._C
import
cache_ops
as
vllm_cache_ops
from
vllm._C
import
ops
as
vllm_ops
import
vllm._C
except
ImportError
as
e
:
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
with
contextlib
.
suppress
(
ImportError
):
# ruff: noqa: F401
import
vllm._punica_C
def
is_custom_op_supported
(
op_name
:
str
)
->
bool
:
op
,
overloads
=
torch
.
_C
.
_jit_get_operation
(
op_name
)
return
op
is
not
None
# activation ops
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
silu_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_and_mul
(
out
,
x
)
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_tanh_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_fast
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_new
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_new
(
out
,
x
)
# page attention ops
...
...
@@ -53,7 +65,7 @@ def paged_attention_v1(
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
vllm_ops
.
paged_attention_v1
(
torch
.
ops
.
_C
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -83,7 +95,7 @@ def paged_attention_v2(
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
vllm_ops
.
paged_attention_v2
(
torch
.
ops
.
_C
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
...
...
@@ -100,8 +112,8 @@ def rotary_embedding(
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
)
->
None
:
vllm_ops
.
rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
)
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
@@ -109,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
# layer norm ops
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
vllm_ops
.
rms_norm
(
out
,
input
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
rms_norm
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
vllm_ops
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# quantization ops
...
...
@@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
awq_dequantize
(
qweight
,
scales
,
zeros
,
split_k_iters
,
thx
,
thy
)
return
torch
.
ops
.
_C
.
awq_dequantize
(
qweight
,
scales
,
zeros
,
split_k_iters
,
thx
,
thy
)
def
awq_gemm
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
return
torch
.
ops
.
_C
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
# gptq
...
...
@@ -144,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
return
torch
.
ops
.
_C
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
vllm_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
lookup_table
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
torch
.
ops
.
_C
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
# marlin
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
size_m
,
size_n
,
size_k
)
return
torch
.
ops
.
_C
.
marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
size_m
,
size_n
,
size_k
)
# marlin_24
...
...
@@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
return
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
# cutlass
...
...
@@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
vllm_ops
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
scale_a
,
scale_b
)
torch
.
ops
.
_C
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
scale_a
,
scale_b
)
return
out
...
...
@@ -198,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
# gptq_marlin
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
return
torch
.
ops
.
_C
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
...
@@ -220,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
)
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
)
# fp8
...
...
@@ -259,9 +272,9 @@ def scaled_fp8_quant(
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
if
scale
is
None
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
vllm_ops
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
vllm_ops
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
...
...
@@ -284,14 +297,14 @@ def scaled_int8_quant(
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
if
scale
is
not
None
:
# static-per-tensor quantization.
vllm_ops
.
static_scaled_int8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
)
return
output
,
scale
# dynamic-per-token quantization.
input_scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
vllm_ops
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
)
return
output
,
input_scales
...
...
@@ -300,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_token_ids
,
experts_ids
,
num_tokens_post_pad
)
torch
.
ops
.
_C
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_token_ids
,
experts_ids
,
num_tokens_post_pad
)
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
token_expert_indicies
:
torch
.
Tensor
,
gating_output
:
float
)
->
None
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
)
def
reshape_and_cache
(
...
...
@@ -314,8 +334,9 @@ def reshape_and_cache(
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
vllm_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
def
reshape_and_cache_flash
(
...
...
@@ -326,25 +347,115 @@ def reshape_and_cache_flash(
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
)
->
None
:
vllm_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm
_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
torch
.
ops
.
_C
_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm
_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
torch
.
ops
.
_C
_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
def
convert_fp8
(
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
float
=
1.0
,
kv_dtype
:
str
=
"fp8"
)
->
None
:
vllm_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
torch
.
ops
.
_C_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
return
torch
.
ops
.
_C_cuda_utils
.
get_device_attribute
(
attribute
,
device
)
def
get_max_shared_memory_per_block_device_attribute
(
device
:
int
)
->
int
:
# ruff: noqa: E501
return
torch
.
ops
.
_C_cuda_utils
.
get_max_shared_memory_per_block_device_attribute
(
device
)
# custom ar
def
init_custom_ar
(
meta
:
torch
.
Tensor
,
rank_data
:
torch
.
Tensor
,
handles
:
List
[
str
],
offsets
:
List
[
int
],
rank
:
int
,
full_nvlink
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
meta
,
rank_data
,
handles
,
offsets
,
rank
,
full_nvlink
)
def
should_custom_ar
(
inp
:
torch
.
Tensor
,
max_size
:
int
,
world_size
:
int
,
full_nvlink
:
bool
)
->
bool
:
return
torch
.
ops
.
_C_custom_ar
.
should_custom_ar
(
inp
,
max_size
,
world_size
,
full_nvlink
)
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
def
all_reduce_unreg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
reg_buffer
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_unreg
(
fa
,
inp
,
reg_buffer
,
out
)
#TODO: cuda_utils, custom_ar
def
dispose
(
fa
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
dispose
(
fa
)
def
meta_size
()
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
meta_size
()
def
register_buffer
(
fa
:
int
,
t
:
torch
.
Tensor
,
handles
:
List
[
str
],
offsets
:
List
[
int
])
->
None
:
return
torch
.
ops
.
_C_custom_ar
.
register_buffer
(
fa
,
t
,
handles
,
offsets
)
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
Tuple
[
List
[
str
],
List
[
int
]]:
return
torch
.
ops
.
_C_custom_ar
.
get_graph_buffer_ipc_meta
(
fa
)
def
register_graph_buffers
(
fa
:
int
,
handles
:
List
[
str
],
offsets
:
List
[
List
[
int
]])
->
None
:
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
# punica
def
dispatch_bgmv
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
Tensor
,
layer_idx
:
int
,
scale
:
float
,
)
->
None
:
torch
.
ops
.
_punica_C
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def
dispatch_bgmv_low_level
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
Tensor
,
layer_idx
:
int
,
scale
:
float
,
h_in
:
int
,
h_out
:
int
,
y_offset
:
int
,
)
->
None
:
torch
.
ops
.
_punica_C
.
dispatch_bgmv_low_level
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
,
h_in
,
h_out
,
y_offset
,
)
vllm/attention/backends/flash_attn.py
View file @
5467ac31
...
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
._C
import
cache_
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
...
...
@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
...
...
@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
...
...
@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
cache_
ops
.
reshape_and_cache_flash
(
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
5467ac31
...
...
@@ -6,6 +6,7 @@ import torch.distributed as dist
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
(
...
...
@@ -15,7 +16,11 @@ from vllm.logger import init_logger
try
:
import
pynvml
from
vllm._C
import
custom_ar
# Simulate ImportError if custom_ar ops are not supported.
if
not
ops
.
is_custom_op_supported
(
"_C_custom_ar::meta_size"
):
raise
ImportError
(
"custom_ar"
,
__file__
)
custom_ar
=
True
@
contextmanager
def
_nvml
():
...
...
@@ -27,7 +32,7 @@ try:
except
ImportError
:
# For AMD GPUs
custom_ar
=
Non
e
custom_ar
=
Fals
e
pynvml
=
None
@
contextmanager
...
...
@@ -97,7 +102,7 @@ class CustomAllreduce:
self
.
_IS_CAPTURING
=
False
self
.
disabled
=
True
if
custom_ar
is
None
:
if
not
custom_ar
:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
...
...
@@ -175,7 +180,7 @@ class CustomAllreduce:
# Metadata is composed of two parts: metadata for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self
.
meta
=
torch
.
zeros
(
custom_ar
.
meta_size
()
+
max_size
,
self
.
meta
=
torch
.
zeros
(
ops
.
meta_size
()
+
max_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
...
...
@@ -196,9 +201,8 @@ class CustomAllreduce:
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
full_nvlink
self
.
_ptr
=
custom_ar
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
handles
,
offsets
,
rank
,
self
.
full_nvlink
)
self
.
_ptr
=
ops
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
handles
,
offsets
,
rank
,
self
.
full_nvlink
)
self
.
register_buffer
(
self
.
buffer
)
@
contextmanager
...
...
@@ -252,31 +256,31 @@ class CustomAllreduce:
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
handles
,
offsets
=
self
.
_get_ipc_meta
(
inp
)
custom_ar
.
register_buffer
(
self
.
_ptr
,
inp
,
handles
,
offsets
)
ops
.
register_buffer
(
self
.
_ptr
,
inp
,
handles
,
offsets
)
def
register_graph_buffers
(
self
):
handle
,
offset
=
custom_ar
.
get_graph_buffer_ipc_meta
(
self
.
_ptr
)
handle
,
offset
=
ops
.
get_graph_buffer_ipc_meta
(
self
.
_ptr
)
handles
,
offsets
=
self
.
_gather_ipc_meta
((
bytes
(
handle
),
offset
))
logger
.
info
(
"Registering %d cuda graph addresses"
,
len
(
offset
))
custom_ar
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
return
custom_ar
.
should_custom_ar
(
inp
,
self
.
max_size
,
self
.
world_size
,
self
.
full_nvlink
)
return
ops
.
should_custom_ar
(
inp
,
self
.
max_size
,
self
.
world_size
,
self
.
full_nvlink
)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def
all_reduce_reg
(
self
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
=
None
):
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
custom_ar
.
all_reduce_reg
(
self
.
_ptr
,
inp
,
out
)
ops
.
all_reduce_reg
(
self
.
_ptr
,
inp
,
out
)
return
out
# all reduce, assuming inp tensor is NOT IPC registered
def
all_reduce_unreg
(
self
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
=
None
):
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
custom_ar
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
ops
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
return
out
def
custom_all_reduce
(
self
,
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
...
...
@@ -304,7 +308,7 @@ class CustomAllreduce:
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
custom_ar
.
dispose
(
self
.
_ptr
)
ops
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
def
__del__
(
self
):
...
...
vllm/lora/punica.py
View file @
5467ac31
...
...
@@ -4,16 +4,21 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
def
_check_punica_support
():
if
ops
.
is_custom_op_supported
(
"_punica_C::dispatch_bgmv"
):
return
def
_raise_import_error
(
e
):
if
torch
.
cuda
.
get_device_capability
()
<
(
8
,
0
):
raise
ImportError
(
"punica LoRA kernels require compute capability >= 8.0"
)
from
e
"punica LoRA kernels require compute capability >= 8.0"
)
else
:
raise
ImportError
(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set."
)
from
e
"was set."
)
def
bgmv
(
...
...
@@ -41,12 +46,9 @@ def bgmv(
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
_check_punica_support
()
p
unica_kernel
s
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
o
ps
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def
dispatch_bgmv_low_level
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
punica_kernels
.
dispatch_bgmv_low_level
(
_check_punica_support
()
ops
.
dispatch_bgmv_low_level
(
y
,
x
,
w_t_all
,
...
...
@@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor,
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
_check_punica_support
()
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
...
...
@@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor,
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
punica_kernels
.
dispatch_bgmv
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
1.0
)
punica_kernels
.
dispatch_bgmv
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
scale
)
ops
.
dispatch_bgmv
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
1.0
)
ops
.
dispatch_bgmv
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
scale
)
def
add_lora_slice
(
y
:
torch
.
Tensor
,
...
...
@@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
_check_punica_support
()
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
...
...
@@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor,
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
p
unica_kernel
s
.
dispatch_bgmv_low_level
(
o
ps
.
dispatch_bgmv_low_level
(
buffer
,
x
,
wa_t_all
,
...
...
@@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor,
buffer
.
size
(
1
),
0
,
)
p
unica_kernel
s
.
dispatch_bgmv_low_level
(
o
ps
.
dispatch_bgmv_low_level
(
y
,
buffer
,
wb_t_all
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
5467ac31
...
...
@@ -8,7 +8,6 @@ import torch
import
triton
import
triton.language
as
tl
import
vllm._moe_C
as
moe_kernels
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
...
...
@@ -355,7 +354,7 @@ def fused_topk(
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
moe_kernel
s
.
topk_softmax
(
op
s
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
...
...
vllm/utils.py
View file @
5467ac31
...
...
@@ -22,6 +22,7 @@ import psutil
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
enable_trace_function_call
,
init_logger
T
=
TypeVar
(
"T"
)
...
...
@@ -148,12 +149,8 @@ def is_neuron() -> bool:
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from
vllm._C
import
cuda_utils
max_shared_mem
=
(
cuda_util
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
op
s
.
get_max_shared_memory_per_block_device_attribute
(
gpu
))
# A value of 0 would make MAX_SEQ_LEN negative and cause test_attention.py
# to fail.
assert
max_shared_mem
>
0
,
"max_shared_mem can not be zero"
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment