Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eca2c5f7
Unverified
Commit
eca2c5f7
authored
Oct 17, 2024
by
bnellnm
Committed by
GitHub
Oct 17, 2024
Browse files
[Bugfix] Fix support for dimension like integers and ScalarType (#9299)
parent
0f41fbe5
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
127 additions
and
670 deletions
+127
-670
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+6
-8
CMakeLists.txt
CMakeLists.txt
+0
-18
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+4
-205
csrc/core/torch_bindings.cpp
csrc/core/torch_bindings.cpp
+0
-16
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+8
-7
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+3
-2
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+12
-11
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+9
-7
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+8
-7
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+20
-25
python_only_dev.py
python_only_dev.py
+0
-1
setup.py
setup.py
+0
-7
tests/compile/utils.py
tests/compile/utils.py
+4
-4
tests/kernels/test_machete_gemm.py
tests/kernels/test_machete_gemm.py
+5
-4
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+13
-3
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+2
-2
tests/test_scalartype.py
tests/test_scalartype.py
+2
-2
tools/report_build_time_ninja.py
tools/report_build_time_ninja.py
+0
-1
vllm/_core_ext.py
vllm/_core_ext.py
+0
-278
vllm/_custom_ops.py
vllm/_custom_ops.py
+31
-62
No files found.
.buildkite/test-pipeline.yaml
View file @
eca2c5f7
...
...
@@ -230,14 +230,12 @@ steps:
commands
:
-
pytest -v -s compile/test_basic_correctness.py
# TODO: re-write in comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
# source_file_dependencies:
# - vllm/
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
-
label
:
"
PyTorch
Fullgraph
Test"
# 18min
source_file_dependencies
:
-
vllm/
-
tests/compile
commands
:
-
pytest -v -s compile/test_full_graph.py
-
label
:
Kernels Test %N
# 1h each
mirror_hardwares
:
[
amd
]
...
...
CMakeLists.txt
View file @
eca2c5f7
...
...
@@ -83,24 +83,6 @@ endif()
#
find_package
(
Torch REQUIRED
)
#
message
(
STATUS
"Enabling core extension."
)
# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)
set
(
VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp"
)
define_gpu_extension_target
(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES
${
VLLM_EXT_SRC
}
COMPILE_FLAGS
${
CXX_COMPILE_FLAGS
}
USE_SABI 3
WITH_SOABI
)
#
# Forward the non-CUDA device extensions to external CMake scripts.
#
...
...
csrc/core/scalar_type.hpp
View file @
eca2c5f7
#pragma once
#include <torch/custom_class.h>
// For TORCH_CHECK
#include <torch/library.h>
namespace
vllm
{
...
...
@@ -9,12 +10,7 @@ namespace vllm {
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
// can be used as an argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
...
...
@@ -308,204 +304,7 @@ class ScalarType {
}
};
// TORCH_LIBRARY-compatible flavour of ScalarType, i.e. one that also derives
// from torch::CustomClassHolder. Multiple inheritance is used because
// ScalarType cannot both inherit from torch::CustomClassHolder and keep a
// constexpr constructor (torch::CustomClassHolder does not have a constexpr
// destructor).
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
//
// NOTE(review): the fields and helpers used below (mantissa, exponent, bias,
// id(), from_id(), str(), min(), max(), NanRepr, ...) are inherited from
// ScalarType, which is declared outside this block.
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
 public:
  // Builds the type directly from its raw description; all arguments are
  // forwarded unchanged to the ScalarType constructor.
  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
                  bool _signed)
      : ScalarType(exponent, mantissa, bias, _signed) {}

  // Wraps an already constructed ScalarType. Deliberately left implicit so
  // that existing call sites relying on the conversion keep compiling.
  ScalarTypeTorch(ScalarType type) : ScalarType(type) {}

  using Base = ScalarType;
  using Self = ScalarTypeTorch;
  using SelfPtr = c10::intrusive_ptr<Self>;

  // Rejects a `size_bits` that does not fit the type of the `mantissa` field.
  // `signed_` is accepted for interface compatibility but does not take part
  // in the check.
  static void check_size_bits(int64_t size_bits,
                              [[maybe_unused]] bool signed_) {
    using Mantissa = decltype(std::declval<Self>().mantissa);
    check_fits<Mantissa>(
        size_bits,
        "size_bits bit width is too large or small to be represented");
  }

  // Rejects a `bias` that does not fit the type of the `bias` field.
  static void check_bias(int64_t bias) {
    using Bias = decltype(std::declval<Self>().bias);
    check_fits<Bias>(bias, "bias too large or small to be represented");
  }

  // Rejects an `exponent` that does not fit the type of the `exponent` field.
  static void check_exponent(int64_t exponent) {
    using Exponent = decltype(std::declval<Self>().exponent);
    check_fits<Exponent>(
        exponent,
        "exponent bit width is too large or small to be represented");
  }

  // Rejects a `mantissa` that does not fit the type of the `mantissa` field.
  static void check_mantissa(int64_t mantissa) {
    using Mantissa = decltype(std::declval<Self>().mantissa);
    check_fits<Mantissa>(
        mantissa,
        "mantissa bit width is too large or small to be represented");
  }

  // Convenience constructor for a signed integer type. A missing `bias` is
  // treated as 0.
  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
    int64_t const bias_value = bias.value_or(0);
    check_size_bits(size_bits, true);
    check_bias(bias_value);
    return c10::make_intrusive<Self>(ScalarType::int_(size_bits, bias_value));
  }

  // Convenience constructor for an unsigned integer type. A missing `bias` is
  // treated as 0.
  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
    int64_t const bias_value = bias.value_or(0);
    check_size_bits(size_bits, false);
    check_bias(bias_value);
    return c10::make_intrusive<Self>(ScalarType::uint(size_bits, bias_value));
  }

  // Convenience constructor forwarding to ScalarType::float_IEEE754 after
  // validating both bit widths.
  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
    check_mantissa(mantissa);
    check_exponent(exponent);
    return c10::make_intrusive<Self>(
        ScalarType::float_IEEE754(exponent, mantissa));
  }

  // Convenience constructor forwarding to ScalarType::float_ after validating
  // both bit widths. `nan_repr` is converted to NanRepr without a range
  // check.
  static SelfPtr float_(int64_t exponent, int64_t mantissa,
                        bool finite_values_only, int64_t nan_repr) {
    check_mantissa(mantissa);
    check_exponent(exponent);
    return c10::make_intrusive<Self>(ScalarType::float_(
        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
  }

  // This needs to be implemented and throw a TypeError in order for
  // PyTorch's opcheck to work on ops that use ScalarTypes.
  int64_t len() const {
    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
                         "__len__ not implemented");
    // Never reached; kept so the function still has a return statement.
    return 0;
  }

  // Serialize a ScalarType into a tuple of (fieldname, value) pairs.
  // For simplicity only a single pair is produced: ("ScalarType", id()).
  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
    return {{"ScalarType", id()}};
  }

  // Inverse of obj_flatten: only the value of the first pair is read and it
  // is handed to from_id(); the field name is ignored.
  static SelfPtr obj_unflatten(
      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
    auto const& first_pair = std::get<0>(flat_type);
    return c10::make_intrusive<Self>(from_id(std::get<1>(first_pair)));
  }

  // Exposes `field` as a read-only property called `name`. `field` may point
  // at either a data member or a nullary member function of Base.
  template <typename T>
  static void bind_readonly_property(torch::class_<Self>& cls,
                                     std::string const& name, T Base::*field) {
    // Member pointers are trivially copyable, so a plain copy capture is
    // sufficient.
    auto read_field = [field](SelfPtr const& self) {
      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
        return (self.get()->*field)();
      } else {
        return self.get()->*field;
      }
    };

    auto getter_func = [read_field](SelfPtr const& self) {
      auto val = read_field(self);
      // upconvert uint8_t, int32_t etc. to int64_t for python
      if constexpr (std::is_integral_v<T>) {
        return static_cast<int64_t>(val);
      } else {
        return val;
      }
    };

    cls.def_property(name, getter_func);
  }

  // Binds a nullary member function under `name`.
  template <typename MemberFunc, typename Cls>
  static void bind_function(torch::class_<Self>& cls, const std::string& name,
                            MemberFunc Cls::*member) {
    cls.def(name, [member](SelfPtr const& self) {
      return (self.get()->*member)();
    });
  }

  // Binds an arbitrary callable (typically a lambda taking SelfPtr) under
  // `name`.
  template <typename Func>
  static void bind_function(torch::class_<Self>& cls, const std::string& name,
                            Func func) {
    cls.def(name, func);
  }

  // Binds a static function under `name`.
  template <typename Func>
  static void bind_static_function(torch::class_<Self>& cls,
                                   const std::string& name, Func func) {
    cls.def_static(name, func);
  }

  // Registers the class with `lib` under the name "ScalarType", together
  // with its properties, methods and convenience constructors.
  static void bind_class(torch::Library& lib) {
    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
                   .def(torch::init<int64_t, int64_t, int64_t, bool>());

    // Bind Properties
    bind_readonly_property(cls, "mantissa", &Base::mantissa);
    bind_readonly_property(cls, "exponent", &Base::exponent);
    bind_readonly_property(cls, "bias", &Base::bias);
    bind_readonly_property(cls, "signed", &Base::is_signed);
    bind_readonly_property(cls, "size_bits", &Base::size_bits);

    // Bind member functions
    bind_function(cls, "is_signed", &Base::is_signed);
    bind_function(cls, "is_integer", &Base::is_integer);
    bind_function(cls, "is_floating_point", &Base::is_floating_point);
    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
    bind_function(cls, "has_nans", &Base::has_nans);
    bind_function(cls, "has_infs", &Base::has_infs);
    bind_function(cls, "has_bias", &Base::has_bias);

    // max()/min() are passed through std::visit so that whichever
    // alternative is held gets wrapped in a c10::IValue.
    bind_function(cls, "max", [](SelfPtr const& self) {
      return std::visit([](auto arg) { return c10::IValue(arg); },
                        self.get()->max());
    });
    bind_function(cls, "min", [](SelfPtr const& self) {
      return std::visit([](auto arg) { return c10::IValue(arg); },
                        self.get()->min());
    });

    bind_function(cls, "__len__", &ScalarTypeTorch::len);
    bind_function(cls, "__str__", &Base::str);
    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
      return *self == *other;
    });
    bind_function(cls, "__repr__", [](SelfPtr const& self) {
      return "ScalarType." + self.get()->str();
    });

    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
    bind_static_function(cls, "__obj_unflatten__",
                         &ScalarTypeTorch::obj_unflatten);

    // Bind static functions (convenience constructors)
    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
  }

 private:
  // Shared range check behind the public check_* helpers: raises through
  // TORCH_CHECK with `message` unless `value` lies within the numeric limits
  // of `Field`. Both bounds are tested so that values below the minimum
  // (e.g. negative widths for an unsigned field) are rejected as well,
  // matching what check_bias has always done.
  template <typename Field>
  static void check_fits(int64_t value, char const* message) {
    TORCH_CHECK(value <= std::numeric_limits<Field>::max() &&
                    value >= std::numeric_limits<Field>::min(),
                message);
  }
};
using
ScalarTypeId
=
int64_t
;
using
ScalarTypeTorchPtr
=
c10
::
intrusive_ptr
<
ScalarTypeTorch
>
;
using
ScalarTypeId
=
ScalarType
::
Id
;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
...
...
csrc/core/torch_bindings.cpp
deleted
100644 → 0
View file @
0f41fbe5
#include <torch/library.h>
#include "scalar_type.hpp"
#include "registration.h"
// Note the CORE extension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
  // Register ScalarType (the custom class describing data types, quantized
  // ones included) in this library so that custom op interfaces can refer
  // to it.
  using vllm::ScalarTypeTorch;
  ScalarTypeTorch::bind_class(lib);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
csrc/moe/marlin_moe_ops.cu
View file @
eca2c5f7
...
...
@@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarType
TorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
vllm
::
ScalarType
Id
const
b_q_type
_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
bool
has_zp
=
b_zeros
.
size
(
1
)
!=
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
->
str
());
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
->
str
());
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
.
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
int
max_par
=
4
;
...
...
@@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe(
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
num_experts
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
...
...
csrc/moe/torch_bindings.cpp
View file @
eca2c5f7
...
...
@@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
// conditionally compiled so impl registration is in source file
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
eca2c5f7
...
...
@@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarType
TorchPtr
const
&
b_q_type
,
vllm
::
ScalarType
Id
const
b_q_type
_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
...
...
@@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarType
TorchPtr
const
&
b_q_type
,
vllm
::
ScalarType
Id
const
&
b_q_type
_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
||
*
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
->
str
());
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
->
str
());
b_q_type
.
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
...
...
@@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
...
...
@@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
...
...
@@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
}
\ No newline at end of file
}
csrc/quantization/machete/machete_pytorch.cu
View file @
eca2c5f7
...
...
@@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
// Interface
//
std
::
vector
<
std
::
string
>
supported_schedules
(
ScalarType
TorchPtr
const
&
btype
)
{
std
::
vector
<
std
::
string
>
supported_schedules
(
ScalarType
Id
const
btype
_id
)
{
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return
scalar_type_dispatch
(
*
btype
,
[
&
](
auto
BType
)
{
vllm
::
ScalarType
b_type
=
ScalarType
::
from_id
(
btype_id
);
return
scalar_type_dispatch
(
b_type
,
[
&
](
auto
BType
)
{
return
GemmDispatcher
<
half_t
,
decltype
(
BType
)
>::
supported_schedules
();
});
#else
...
...
@@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
}
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
ScalarType
TorchPtr
const
&
btype
,
ScalarType
Id
const
btype
_id
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
...
...
@@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
)
{
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
ScalarType
const
btype
=
ScalarType
::
from_id
(
btype_id
);
auto
args
=
PyTorchArguments
{.
A
=
A
,
.
B
=
B
,
.
scales
=
scales
,
...
...
@@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
.
beta
=
beta
,
.
schedule
=
schedule
};
return
scalar_type_dispatch
(
*
btype
,
[
&
](
auto
BType
)
{
return
scalar_type_dispatch
(
btype
,
[
&
](
auto
BType
)
{
return
AT_DISPATCH_SUPPORTED_COMPUTE_TYPES
(
A
.
scalar_type
(),
"machete_gemm"
,
[
&
]
{
using
ComputeType
=
equivalent_cutlass_type_t
<
scalar_t
>
;
...
...
@@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
#endif
}
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarType
TorchPtr
const
&
btype
)
{
return
scalar_type_dispatch
(
*
btype
,
[
&
](
auto
BType
)
{
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
ScalarTypeId
const
btype_id
)
{
ScalarType
const
btype
=
ScalarType
::
from_id
(
btype_id
);
return
scalar_type_dispatch
(
btype
,
[
&
](
auto
BType
)
{
return
PrepackBDispatcher
<
half_t
,
decltype
(
BType
),
half_t
>::
dispatch
(
B
);
});
}
...
...
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
eca2c5f7
...
...
@@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarType
TorchPtr
const
&
b_q_type
,
vllm
::
ScalarType
Id
const
b_q_type
_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
@@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarType
TorchPtr
const
&
b_q_type
,
vllm
::
ScalarType
Id
const
b_q_type
_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
// Verify num_bits
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"num_bits must be uint4b8 or uint8b128. Got = "
,
b_q_type
->
str
());
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"num_bits must be uint4b8 or uint8b128. Got = "
,
b_q_type
.
str
());
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
// Verify M
TORCH_CHECK
(
size_m
==
a
.
size
(
0
),
...
...
@@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24
::
marlin_cuda_2_4
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
b_meta
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
size_n
,
size_m
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
->
size_bits
(),
groupsize
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_m
,
sms
,
max_par
);
b_q_type
.
size_bits
(),
groupsize
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_m
,
sms
,
max_par
);
return
c
;
}
...
...
csrc/torch_bindings.cpp
View file @
eca2c5f7
...
...
@@ -140,13 +140,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantized GEMM for AWQ.
ops
.
def
(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros,
i
nt split_k_iters) -> Tensor"
);
"Tensor _zeros,
SymI
nt split_k_iters) -> Tensor"
);
ops
.
impl
(
"awq_gemm"
,
torch
::
kCUDA
,
&
awq_gemm
);
// Dequantization for AWQ.
ops
.
def
(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros,
i
nt split_k_iters, int thx, int thy) -> Tensor"
);
"Tensor _zeros,
SymI
nt split_k_iters, int thx, int thy) -> Tensor"
);
ops
.
impl
(
"awq_dequantize"
,
torch
::
kCUDA
,
&
awq_dequantize
);
// Note about marlin kernel 'workspace' arguments:
...
...
@@ -166,32 +166,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"
);
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
"Tensor"
);
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"
__torch__.torch.classes._core_C.ScalarType
b_q_type, "
"
i
nt size_m,
i
nt size_n,
i
nt size_k) -> Tensor"
);
"
int
b_q_type, "
"
SymI
nt size_m,
SymI
nt size_n,
SymI
nt size_k) -> Tensor"
);
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops
.
def
(
"machete_supported_schedules(int btype) -> str[]"
);
ops
.
def
(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]"
);
ops
.
def
(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
"machete_gemm(Tensor A, Tensor B, int btype, "
" Tensor? scales, Tensor? zeros, int? group_size, "
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor"
);
ops
.
def
(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor"
);
ops
.
def
(
"machete_prepack_B(Tensor B, int btype) -> Tensor"
);
// conditionally compiled so impl registration is in source file
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
...
...
@@ -201,8 +195,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"
__torch__.torch.classes._core_C.ScalarType
b_q_type, "
"
i
nt size_m,
i
nt size_n,
i
nt size_k, bool is_k_full, "
"
int
b_q_type, "
"
SymI
nt size_m,
SymI
nt size_n,
SymI
nt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor"
);
// conditionally compiled so impl registration is in source file
...
...
@@ -219,32 +213,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type,
i
nt m,
i
nt n) -> Tensor"
);
ops
.
def
(
"ggml_dequantize(Tensor W, int type,
SymI
nt m,
SymI
nt n) -> Tensor"
);
ops
.
impl
(
"ggml_dequantize"
,
torch
::
kCUDA
,
&
ggml_dequantize
);
// mmvq kernel for GGML.
ops
.
def
(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type,
i
nt row) "
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type,
SymI
nt row) "
"-> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_vec_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_vec_a8
);
// mmq kernel for GGML.
ops
.
def
(
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor"
);
ops
.
def
(
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_a8
);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits,
i
nt size_m,
i
nt size_n, "
"
i
nt size_k) -> Tensor"
);
"Tensor! workspace, int num_bits,
SymI
nt size_m,
SymI
nt size_n, "
"
SymI
nt size_k) -> Tensor"
);
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops
.
def
(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace,
i
nt size_m,
i
nt size_n, "
"
i
nt size_k) -> Tensor"
);
"Tensor! workspace,
SymI
nt size_m,
SymI
nt size_n, "
"
SymI
nt size_k) -> Tensor"
);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
...
...
python_only_dev.py
View file @
eca2c5f7
...
...
@@ -39,7 +39,6 @@ assert cwd != package_path, "should not import from the current directory"
files_to_copy
=
[
"vllm/_C.abi3.so"
,
"vllm/_core_C.abi3.so"
,
"vllm/_moe_C.abi3.so"
,
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so"
,
"vllm/vllm_flash_attn/flash_attn_interface.py"
,
...
...
setup.py
View file @
eca2c5f7
...
...
@@ -290,10 +290,6 @@ def _build_custom_ops() -> bool:
return
_is_cuda
()
or
_is_hip
()
or
_is_cpu
()
def _build_core_ext() -> bool:
    """Return True unless the target is Neuron, TPU, OpenVINO or XPU.

    The platform probes are evaluated left to right and stop at the first
    one that reports a match.
    """
    excluded_target = (_is_neuron() or _is_tpu() or _is_openvino()
                       or _is_xpu())
    return not excluded_target
def
get_hipcc_rocm_version
():
# Run the hipcc --version command
result
=
subprocess
.
run
([
'hipcc'
,
'--version'
],
...
...
@@ -456,9 +452,6 @@ def get_requirements() -> List[str]:
ext_modules
=
[]
if
_build_core_ext
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._core_C"
))
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
...
...
tests/compile/utils.py
View file @
eca2c5f7
...
...
@@ -69,11 +69,11 @@ def check_full_graph_support(model,
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
optimization_level
)
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
# Inductor doesn't support fp8/gptq_marlin_24 yet.
# Inductor doesn't support fp8 and the base meta llama uses too
# much memory.
quantization
=
model_kwargs
.
get
(
"quantization"
)
if
(
quantization
==
"fp8"
or
quantization
==
"gptq_marlin"
or
quantization
==
"gptq_marlin_24"
)
and
optimization_level
>=
CompilationLevel
.
INDUCTOR
:
if
((
quantization
==
"fp8"
or
model
==
"meta-llama/Meta-Llama-3-8B"
)
and
optimization_level
>=
CompilationLevel
.
INDUCTOR
):
return
prompts
=
[
...
...
tests/kernels/test_machete_gemm.py
View file @
eca2c5f7
...
...
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
))
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
.
id
))
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
...
...
@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule
=
schedule
,
)
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
(
a
,
w_q_machete
,
wtype
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
(
a
,
w_q_machete
,
wtype
.
id
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
...
...
tests/kernels/test_marlin_gemm.py
View file @
eca2c5f7
...
...
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
...
...
@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert
max_diff
<
0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Call ``ops.gptq_marlin_24_gemm`` from inside a full-graph compiled fn.

    Every argument is forwarded positionally and unchanged; the op's result
    is returned as is.
    """
    output = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
    return output
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
...
...
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
,
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
workspace_24
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_
marlin_24_gemm
(
output
=
marlin_24_gemm
_tester
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
...
...
tests/kernels/test_moe.py
View file @
eca2c5f7
...
...
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad
=
False
)
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
.
id
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
...
...
tests/test_scalartype.py
View file @
eca2c5f7
...
...
@@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple):
max
=
torch
.
iinfo
(
torch_type
).
max
print
(
t
,
min
,
max
,
t
.
min
(),
t
.
max
())
assert
min
==
t
.
min
()
assert
max
==
t
.
max
()
assert
min
==
t
.
min
()
,
f
"min:
{
min
}
!=
{
t
.
min
()
}
"
assert
max
==
t
.
max
()
,
f
"max:
{
max
}
!=
{
t
.
max
()
}
"
tools/report_build_time_ninja.py
View file @
eca2c5f7
...
...
@@ -16,7 +16,6 @@ Typical output looks like this:
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
...
...
vllm/_core_ext.py
deleted
100644 → 0
View file @
0f41fbe5
import
importlib.util
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
import
torch
from
vllm.logger
import
init_logger
# Module-level logger; used below to report a failed `vllm._core_C` import.
logger = init_logger(__name__)

# True when the import machinery can locate the `_core_C` submodule of the
# `vllm` package (the relative name '._core_C' is resolved against 'vllm').
# This only checks that a module spec exists; it does not import the module.
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """How a scalar type encodes NaN values."""

    # nans are not supported
    NONE = 0

    # nans are: Exp all 1s, mantissa not all 0s
    IEEE_754 = 1

    # nans are: Exp all 1s, mantissa all 1s
    EXTD_RANGE_MAX_MIN = 2
if
TYPE_CHECKING
or
not
core_C_available
:
# On platforms where we cannot use/build the C++ core extension (namely
# neuron and tpu), we define the mock ScalarType class here that partially
# mimics the C++ ScalarType class.
#
# We also use this to provide type signatures to the Python LSP for the methods
# in the C++ ScalarType class. So these type signatures should be kept
# in sync with csrc/core/scalar_type.hpp
from
dataclasses
import
dataclass
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.

    This Python class is only a partial mock: `min`, `max`, `__str__` and
    `__repr__` raise NotImplementedError.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer excluding the sign bit if
    this is an integer type.
    """

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    _finite_values_only: bool = False
    """
    Private: True if the type cannot represent infinities; use `has_infs()`
    instead of reading this directly.
    """

    nan_repr: int = NanRepr.IEEE_754.value
    """
    How NaNs are represented in this scalar type, holds a NanRepr value.
    (not applicable for integer types)
    """

    @property
    def size_bits(self) -> int:
        "Total width in bits: exponent bits + mantissa bits + sign bit."
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        raise NotImplementedError

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        raise NotImplementedError

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        # Previously the body was `...`, which silently returned None.
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type has a NaN representation (nan_repr is not NONE)"
        return self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754.value and \
            not self._finite_values_only

    def __str__(self) -> str:
        raise NotImplementedError

    def __repr__(self) -> str:
        raise NotImplementedError

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        # Integer types have a zero-width exponent; the sign bit is counted
        # separately by `size_bits`, so the mantissa holds size_bits - 1 bits.
        return cls(0, size_bits - 1, bias if bias else 0, True)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create an unsigned integer scalar type."""
        # Zero-width exponent marks the type as an integer; with no sign bit
        # all size_bits bits belong to the mantissa.
        return cls(0, size_bits, bias if bias else 0, False)

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True)

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True, finite_values_only, nan_repr)
elif
core_C_available
:
try
:
import
vllm._core_C
# noqa: F401
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
if
(
hasattr
(
torch
,
"_library"
)
and
hasattr
(
torch
.
_library
,
"register_fake_class"
)):
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
    """Fake-class stand-in registered for `_core_C::ScalarType`.

    Wraps a real ScalarType object (stored on `self.ScalarType`) and
    forwards every property/method call to it.
    """

    def __init__(self, scalar_type):
        # The wrapped object that all accessors below delegate to.
        self.ScalarType = scalar_type

    # `*_getter` methods expose the wrapped object's fields as plain methods.
    def bias_getter(self) -> int:
        return self.ScalarType.bias

    def exponent_getter(self) -> int:
        return self.ScalarType.exponent

    def mantissa_getter(self) -> int:
        return self.ScalarType.mantissa

    def signed_getter(self) -> bool:
        return self.ScalarType.signed

    def size_bits_getter(self) -> int:
        return self.ScalarType.size_bits

    @property
    def size_bits(self) -> int:
        return self.ScalarType.size_bits

    def min(self) -> Union[int, float]:
        return self.ScalarType.min()

    def max(self) -> Union[int, float]:
        return self.ScalarType.max()

    def is_signed(self) -> bool:
        return self.ScalarType.is_signed()

    def is_floating_point(self) -> bool:
        return self.ScalarType.is_floating_point()

    def is_integer(self) -> bool:
        return self.ScalarType.is_integer()

    def has_bias(self) -> bool:
        return self.ScalarType.has_bias()

    def has_infs(self) -> bool:
        return self.ScalarType.has_infs()

    def has_nans(self) -> bool:
        return self.ScalarType.has_nans()

    def is_ieee_754(self) -> bool:
        return self.ScalarType.is_ieee_754()

    def __str__(self) -> str:
        return self.ScalarType.__str__()

    def __repr__(self) -> str:
        return self.ScalarType.__repr__()

    def __len__(self) -> int:
        return self.ScalarType.__len__()

    def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
        # Flattening is delegated to the real class, applied to the wrapped
        # object.
        return torch.classes._core_C.ScalarType.__obj_flatten__(
            self.ScalarType)

    @classmethod
    def __obj_unflatten__(
            cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
        # Rebuilds the real object via the real class, then wraps it in
        # `cls`.
        # NOTE(review): the annotation says 'ScalarType' but the value
        # returned is a `cls` (fake) instance -- confirm which is intended.
        return cls(
            torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))

    # The convenience constructors below call the module-level `ScalarType`
    # directly and return its result; they do not wrap it in `cls`.
    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        return ScalarType.int_(size_bits, bias)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        return ScalarType.uint(size_bits, bias)

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        return ScalarType.float_IEEE754(exponent, mantissa)

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        return ScalarType.float_(exponent, mantissa, finite_values_only,
                                 nan_repr)
vllm/_custom_ops.py
View file @
eca2c5f7
...
...
@@ -6,9 +6,9 @@ import torch
import
torch.library
import
vllm.envs
as
envs
from
vllm._core_ext
import
ScalarType
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
logger
=
init_logger
(
__name__
)
...
...
@@ -306,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
b_q_type
,
size_m
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
)
...
...
@@ -316,8 +316,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
b_q_type
:
ScalarType
,
size_m
:
torch
.
SymInt
,
size_n
:
torch
.
SymInt
,
size_k
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::gptq_marlin_gemm"
)
...
...
@@ -329,17 +330,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
i
nt
,
size_n
:
i
nt
,
size_k
:
i
nt
,
size_m
:
torch
.
SymI
nt
,
size_n
:
torch
.
SymI
nt
,
size_k
:
torch
.
SymI
nt
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::ggml_dequantize"
)
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
int
,
n
:
int
)
->
torch
.
Tensor
:
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
torch
.
SymInt
,
n
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
...
...
@@ -347,7 +349,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
i
nt
,
row
:
torch
.
SymI
nt
,
)
->
torch
.
Tensor
:
return
torch
.
empty
((
1
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
...
...
@@ -356,7 +358,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
i
nt
,
row
:
torch
.
SymI
nt
,
)
->
torch
.
Tensor
:
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
...
...
@@ -365,8 +367,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def
_marlin_qqq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
s_tok
:
torch
.
Tensor
,
s_ch
:
torch
.
Tensor
,
s_group
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
i
nt
,
size_n
:
i
nt
,
size_k
:
i
nt
)
->
torch
.
Tensor
:
size_m
:
torch
.
SymI
nt
,
size_n
:
torch
.
SymI
nt
,
size_k
:
torch
.
SymI
nt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
torch
.
float16
,
device
=
a
.
device
)
...
...
@@ -374,16 +376,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::marlin_gemm"
)
def
_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
i
nt
,
size_n
:
i
nt
,
size_k
:
i
nt
)
->
torch
.
Tensor
:
size_m
:
torch
.
SymI
nt
,
size_n
:
torch
.
SymI
nt
,
size_k
:
torch
.
SymI
nt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
torch
.
float16
,
device
=
a
.
device
)
@
register_fake
(
"_C::awq_dequantize"
)
def
_awq_dequantize_fake
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
i
nt
,
thy
:
int
)
->
torch
.
Tensor
:
zeros
:
torch
.
Tensor
,
split_k_iters
:
torch
.
SymI
nt
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
in_c
=
qweight
.
size
(
0
)
qout_c
=
qweight
.
size
(
1
)
out_c
=
qout_c
*
8
...
...
@@ -394,7 +396,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::awq_gemm"
)
def
_awq_gemm_fake
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
split_k_iters
:
i
nt
)
->
torch
.
Tensor
:
split_k_iters
:
torch
.
SymI
nt
)
->
torch
.
Tensor
:
num_in_feats
=
input
.
size
(
0
)
return
torch
.
empty
((
split_k_iters
,
num_in_feats
,
qweight
.
size
(
1
)
*
8
),
dtype
=
input
.
dtype
,
...
...
@@ -429,8 +431,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::fp8_marlin_gemm"
)
def
_fp8_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
,
size_m
:
torch
.
SymInt
,
size_n
:
torch
.
SymInt
,
size_k
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
@
register_fake
(
"_C::machete_gemm"
)
...
...
@@ -457,40 +460,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return
torch
.
empty_like
(
b_q_weight
,
memory_format
=
torch
.
contiguous_format
)
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    cu_seq_len: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Fake implementation registered for `_C::causal_conv1d_fwd`.

    Ignores every argument, performs no computation and returns None.
    """
    result = None
    return result
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Fake implementation registered for `_C::causal_conv1d_update`.

    Ignores every argument, performs no computation and returns None.
    """
    result = None
    return result
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D_: Optional[torch.Tensor],
    z_: Optional[torch.Tensor],
    delta_bias_: Optional[torch.Tensor],
    delta_softplus: bool,
    cu_seq_len: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    ssm_states: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Fake implementation registered for `_C::selective_scan_fwd`.

    Ignores every argument, performs no computation and returns None.
    """
    result = None
    return result
# cutlass
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
...
...
@@ -611,7 +580,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
has_zp
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
,
use_fp32_reduce
)
...
...
@@ -627,7 +596,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# machete
def
machete_supported_schedules
(
b_type
:
ScalarType
)
->
List
[
str
]:
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
b_type
)
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
b_type
.
id
)
def
machete_gemm
(
...
...
@@ -642,13 +611,13 @@ def machete_gemm(
beta
:
Optional
[
float
]
=
None
,
schedule
:
Optional
[
str
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
machete_gemm
(
a
,
b_q
,
b_type
,
b_scales
,
b_zeros
,
return
torch
.
ops
.
_C
.
machete_gemm
(
a
,
b_q
,
b_type
.
id
,
b_scales
,
b_zeros
,
b_group_size
,
c
,
alpha
,
beta
,
schedule
)
def
machete_prepack_B
(
b_q_weight
:
torch
.
Tensor
,
b_type
:
ScalarType
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
.
id
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
...
...
@@ -862,10 +831,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
topk_ids
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zero_points
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
i
nt
,
size_
k
:
int
,
is_k_full
:
bool
,
num_experts
:
i
nt
,
topk
:
int
,
moe_block_size
:
int
,
replicate_input
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
torch
.
SymI
nt
,
size_
n
:
torch
.
SymInt
,
size_k
:
torch
.
SymI
nt
,
is_k_full
:
bool
,
num_experts
:
int
,
topk
:
int
,
moe_block_size
:
int
,
replicate_input
:
bool
,
apply_weights
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
topk
,
size_n
),
dtype
=
a
.
dtype
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment