Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96d999fb
Unverified
Commit
96d999fb
authored
Nov 18, 2024
by
Lucas Wilkinson
Committed by
GitHub
Nov 18, 2024
Browse files
[Kernel] Initial Machete W4A8 support + Refactors (#9855)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
c2170a5b
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
600 additions
and
151 deletions
+600
-151
csrc/quantization/machete/machete_prepacked_layout.cuh
csrc/quantization/machete/machete_prepacked_layout.cuh
+42
-12
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+46
-74
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+29
-6
tests/kernels/test_machete_mm.py
tests/kernels/test_machete_mm.py
+406
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+44
-31
vllm/model_executor/layers/quantization/kernels/machete.py
vllm/model_executor/layers/quantization/kernels/machete.py
+9
-7
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+24
-21
No files found.
csrc/quantization/machete/machete_prepacked_layout.cuh
View file @
96d999fb
...
@@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
...
@@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
// The contract here is that the `TiledMma` determined below matches the one
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
// required along with the kernel schedule)
template
<
typename
ElementA_
,
typename
ElementB_
,
typename
Element
D
_
,
template
<
typename
ElementA_
,
typename
ElementB_
,
typename
Element
Convert
_
,
typename
AccumulatorT
,
class
LayoutB
,
class
KernelSchedule
,
typename
AccumulatorT
,
class
LayoutB
,
class
KernelSchedule
,
typename
IlvBlkLayout_
=
IlvBlkLayoutAuto
>
typename
IlvBlkLayout_
=
IlvBlkLayoutAuto
>
// clang-format on
// clang-format on
...
@@ -49,19 +49,26 @@ struct PrepackedLayoutBTemplate {
...
@@ -49,19 +49,26 @@ struct PrepackedLayoutBTemplate {
using
MmaType
=
ElementA_
;
using
MmaType
=
ElementA_
;
using
ElementA
=
ElementA_
;
using
ElementA
=
ElementA_
;
using
ElementB
=
ElementB_
;
using
ElementB
=
ElementB_
;
using
ElementD
=
ElementD_
;
using
ElementAccumulator
=
AccumulatorT
;
using
ElementAccumulator
=
AccumulatorT
;
// Element type for internal accumulation
using
ElementMma
=
MmaType
;
using
ElementMma
=
MmaType
;
// Only use interleaved layouts for subbyte weights, prmt instructions makes
// Interleave for 4bit bit types when we are not upconverting to fp8 or int8,
// non-interleaved layouts for 8bit+ weights efficient enough we don't need
// in those cases case we use a LUT using prmt instructions to upconvert and
// iterleaved layouts
// is more efficient if the data is not interleaved For 8bit+ prmt
// instructions makes non-interleaved layouts efficient enough we don't need
// iterleaved layouts (and can reuse more of the existing cutlass converts)
static
constexpr
bool
should_interleave
=
sizeof_bits_v
<
ElementB
>
<=
4
&&
!
std
::
is_same_v
<
ElementConvert_
,
cutlass
::
float_e4m3_t
>
&&
!
std
::
is_same_v
<
ElementConvert_
,
int8_t
>
;
// Only use interleaved layouts for subbyte weights,
using
IlvdBlkLayout
=
std
::
conditional_t
<
using
IlvdBlkLayout
=
std
::
conditional_t
<
std
::
is_same_v
<
IlvBlkLayout_
,
IlvBlkLayoutAuto
>
,
std
::
is_same_v
<
IlvBlkLayout_
,
IlvBlkLayoutAuto
>
,
std
::
conditional_t
<
sizeof_bits_v
<
ElementB
>
<=
4
,
std
::
conditional_t
<
should_interleave
,
decltype
(
get_interleaved_blk_layout
<
decltype
(
get_interleaved_blk_layout
<
ElementB
,
sizeof_bits_v
<
Element
A
>
,
32
>
()),
ElementB
,
sizeof_bits_v
<
Element
Convert_
>
,
32
>
()),
void
>
,
void
>
,
IlvBlkLayout_
>
;
IlvBlkLayout_
>
;
...
@@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
...
@@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
auto
frgV
=
get
<
1
,
0
>
(
layout_no_interleave
);
auto
frgV
=
get
<
1
,
0
>
(
layout_no_interleave
);
auto
ilvdBlk
=
IlvdBlkLayout
{};
auto
ilvdBlk
=
IlvdBlkLayout
{};
static_assert
(
size
(
frgV
)
%
4
==
0
,
"FrgV must be divisible by 4"
);
static_assert
(
size
(
frgV
)
%
size
(
ilvdBlk
)
==
0
,
"FrgV must be divisible by size(ilvdBlk)"
);
auto
ilvd_FrgV
=
make_layout
(
auto
ilvd_FrgV
=
make_layout
(
make_shape
(
shape
(
ilvdBlk
),
Int
<
size
(
frgV
)
/
size
(
ilvdBlk
)
>
{}),
make_shape
(
shape
(
ilvdBlk
),
Int
<
size
(
frgV
)
/
size
(
ilvdBlk
)
>
{}),
make_stride
(
stride
(
ilvdBlk
),
size
(
ilvdBlk
)));
make_stride
(
stride
(
ilvdBlk
),
size
(
ilvdBlk
)));
...
@@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
...
@@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
return
group
<
1
,
3
>
(
result
(
_
,
repeat
<
rank
<
1
>
(
result
)
>
(
_
)));
return
group
<
1
,
3
>
(
result
(
_
,
repeat
<
rank
<
1
>
(
result
)
>
(
_
)));
}
}
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
template
<
typename
Shape_NKL
>
CUTE_HOST_DEVICE
static
constexpr
auto
TVbNbKL_to_offset_copy
(
Shape_NKL
shape_mkl
)
{
auto
layout
=
TVbNbKL_to_offset
(
shape_mkl
);
return
make_layout
(
coalesce
(
get
<
0
>
(
layout
)),
get
<
1
>
(
layout
),
get
<
2
>
(
layout
));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template
<
typename
Shape_NKL
>
template
<
typename
Shape_NKL
>
CUTE_HOST_DEVICE
static
constexpr
auto
ilvd_NKbNbKL_to_offset
(
CUTE_HOST_DEVICE
static
constexpr
auto
ilvd_NKbNbKL_to_offset
(
...
@@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
...
@@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
return
group
<
1
,
3
>
(
result
(
_
,
repeat
<
rank
<
1
>
(
result
)
>
(
_
)));
return
group
<
1
,
3
>
(
result
(
_
,
repeat
<
rank
<
1
>
(
result
)
>
(
_
)));
}
}
// (BlocksN, BlocksK, L) -> (storage_idx)
template
<
typename
Shape_NKL
>
CUTE_HOST_DEVICE
static
constexpr
auto
bNbKL_to_offset
(
Shape_NKL
shape_mkl
)
{
// (BlocksN, BlocksK, L)
auto
blocks_shape
=
cute
::
transform
(
shape_mkl
,
append
(
PPBlockShape_NK
{},
_1
{}),
[](
auto
x
,
auto
y
)
{
return
x
/
y
;
});
auto
stride
=
size
(
PPBlockShape_NK
{});
// (BlocksN, BlocksK, L) -> (storage_idx)
return
make_layout
(
blocks_shape
,
compact_col_major
(
blocks_shape
,
stride
));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template
<
class
Shape_NKL
>
template
<
class
Shape_NKL
>
CUTE_HOST_DEVICE
static
auto
TVbNbK_to_NKL
(
Shape_NKL
shape_mkl
)
{
CUTE_HOST_DEVICE
static
auto
TVbNbK_to_NKL
(
Shape_NKL
shape_mkl
)
{
...
...
csrc/quantization/machete/machete_pytorch.cu
View file @
96d999fb
...
@@ -8,89 +8,61 @@ namespace machete {
...
@@ -8,89 +8,61 @@ namespace machete {
using
namespace
vllm
;
using
namespace
vllm
;
//
std
::
vector
<
std
::
string
>
supported_schedules
(
// Utils (type dispatching)
at
::
ScalarType
a_type
,
int64_t
b_type_id
,
//
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
template
<
typename
Fn
>
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
static
auto
scalar_type_dispatch
(
ScalarType
const
&
type
,
Fn
fn
)
{
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
if
(
type
==
vllm
::
kU4
)
{
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
return
fn
(
cutlass
::
uint4b_t
{});
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
}
else
if
(
type
==
vllm
::
kU8
)
{
return
supported_schedules_dispatch
({
return
fn
(
cutlass
::
uint8_t
{});
.
a_type
=
a_type
,
}
else
if
(
type
==
vllm
::
kU4B8
)
{
.
b_type
=
b_type
,
return
fn
(
cutlass
::
vllm_uint4b8_t
{});
.
maybe_group_scales_type
=
maybe_group_scales_type
,
}
else
if
(
type
==
vllm
::
kU8B128
)
{
.
maybe_group_zeros_type
=
maybe_group_zeros_type
,
return
fn
(
cutlass
::
vllm_uint8b128_t
{});
.
maybe_channel_scales_type
=
maybe_channel_scales_type
,
}
else
{
.
maybe_token_scales_type
=
maybe_token_scales_type
,
TORCH_CHECK
(
false
,
"Unsupported type "
,
type
.
str
());
.
maybe_out_type
=
maybe_out_type
,
}
}
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std
::
vector
<
std
::
string
>
supported_schedules
(
ScalarTypeId
const
btype_id
)
{
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm
::
ScalarType
b_type
=
ScalarType
::
from_id
(
btype_id
);
return
scalar_type_dispatch
(
b_type
,
[
&
](
auto
BType
)
{
return
GemmDispatcher
<
half_t
,
decltype
(
BType
)
>::
supported_schedules
();
});
});
#else
TORCH_CHECK
(
false
,
"Machete requires CUDA 12.0 or later"
);
#endif
}
}
torch
::
Tensor
ge
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
ScalarTypeId
cons
t
btype_id
,
int64_
t
b
_
type_id
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zero
s
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scale
s
,
c10
::
optional
<
int64_t
>
group_
si
ze
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_
group_ze
ros
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
c10
::
optional
<
std
::
string
>
schedule
)
{
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
c10
::
optional
<
std
::
string
>
maybe_schedule
)
{
ScalarType
const
btype
=
ScalarType
::
from_id
(
btype_id
);
ScalarType
const
b
_
type
=
ScalarType
::
from_id
(
b
_
type_id
);
auto
args
=
PyTorchArguments
{.
A
=
A
,
return
mm_dispatch
(
{.
A
=
A
,
.
B
=
B
,
.
B
=
B
,
.
scales
=
scales
,
.
b_type
=
b_type
,
.
zeros
=
zeros
,
.
maybe_out_type
=
maybe_out_type
,
.
group_size
=
group_size
,
.
maybe_group_scales
=
maybe_group_scales
,
.
C
=
C
,
.
maybe_group_zeros
=
maybe_group_zeros
,
.
alpha
=
alpha
,
.
maybe_group_size
=
maybe_group_size
,
.
beta
=
beta
,
.
maybe_channel_scales
=
maybe_channel_scales
,
.
schedule
=
schedule
};
.
maybe_token_scales
=
maybe_token_scales
,
.
maybe_schedule
=
maybe_schedule
});
return
scalar_type_dispatch
(
btype
,
[
&
](
auto
BType
)
{
return
AT_DISPATCH_SUPPORTED_COMPUTE_TYPES
(
A
.
scalar_type
(),
"machete_gemm"
,
[
&
]
{
using
ComputeType
=
equivalent_cutlass_type_t
<
scalar_t
>
;
return
GemmDispatcher
<
ComputeType
,
decltype
(
BType
)
>::
dispatch
(
args
);
});
});
#else
TORCH_CHECK
(
false
,
"Machete requires CUDA 12.0 or later"
);
#endif
}
}
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
ScalarTypeId
const
btype_id
)
{
torch
::
Tensor
prepack_B
(
ScalarType
const
btype
=
ScalarType
::
from_id
(
btype_id
);
torch
::
Tensor
const
&
B
,
at
::
ScalarType
const
&
a_type
,
int64_t
b_type_id
,
return
scalar_type_dispatch
(
btype
,
[
&
](
auto
BType
)
{
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
return
PrepackBDispatcher
<
half_t
,
decltype
(
BType
),
half_t
>::
dispatch
(
B
);
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
});
return
prepack_B_dispatch
(
{.
B
=
B
,
.
a_type
=
a_type
,
.
b_type
=
b_type
,
.
maybe_group_scales_type
=
maybe_group_scales_type
});
}
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"machete_prepack_B"
,
&
prepack_B
);
m
.
impl
(
"machete_prepack_B"
,
&
prepack_B
);
m
.
impl
(
"machete_
ge
mm"
,
&
ge
mm
);
m
.
impl
(
"machete_mm"
,
&
mm
);
}
}
// use CatchAll since supported_schedules has no tensor arguments
// use CatchAll since supported_schedules has no tensor arguments
...
...
csrc/torch_bindings.cpp
View file @
96d999fb
...
@@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl in source file
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops
.
def
(
"machete_supported_schedules(int btype) -> str[]"
);
ops
.
def
(
ops
.
def
(
"machete_gemm(Tensor A, Tensor B, int btype, "
"machete_supported_schedules("
" Tensor? scales, Tensor? zeros, int? group_size, "
" ScalarType a_type,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
" int b_type,"
"-> Tensor"
);
" ScalarType? maybe_group_scales_type,"
ops
.
def
(
"machete_prepack_B(Tensor B, int btype) -> Tensor"
);
" ScalarType? maybe_group_zeros_type,"
" ScalarType? maybe_channel_scales_type,"
" ScalarType? maybe_token_scales_type,"
" ScalarType? maybe_out_type"
") -> str[]"
);
ops
.
def
(
"machete_mm("
" Tensor A,"
" Tensor B,"
" int b_type,"
" ScalarType? out_type,"
" Tensor? group_scales,"
" Tensor? group_zeros,"
" int? group_size,"
" Tensor? channel_scales,"
" Tensor? token_scales,"
" str? schedule"
") -> Tensor"
);
ops
.
def
(
"machete_prepack_B("
" Tensor B,"
" ScalarType a_type,"
" int b_type,"
" ScalarType? group_scales_type"
") -> Tensor"
);
// conditionally compiled so impl registration is in source file
// conditionally compiled so impl registration is in source file
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
...
...
tests/kernels/test_machete_
ge
mm.py
→
tests/kernels/test_machete_mm.py
View file @
96d999fb
"""Tests for the machete kernel.
"""Tests for the machete kernel.
Run `pytest tests/kernels/test_machete_
ge
mm.py`.
Run `pytest tests/kernels/test_machete_mm.py`.
"""
"""
import
math
import
math
from
typing
import
Optional
,
Tuple
from
dataclasses
import
dataclass
,
fields
from
typing
import
List
,
Optional
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -20,6 +21,13 @@ CUDA_DEVICES = [
...
@@ -20,6 +21,13 @@ CUDA_DEVICES = [
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU
=
current_platform
.
get_device_capability
()[
0
]
>=
9
MNK_SHAPES
=
[
MNK_SHAPES
=
[
(
1
,
128
,
128
),
(
1
,
128
,
128
),
(
1
,
512
,
1024
),
(
1
,
512
,
1024
),
...
@@ -36,14 +44,76 @@ MNK_SHAPES = [
...
@@ -36,14 +44,76 @@ MNK_SHAPES = [
(
1024
,
8192
,
4096
),
(
1024
,
8192
,
4096
),
]
]
ACT_TYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
GROUP_SIZES_TO_TEST
:
List
[
Optional
[
int
]]
=
[
128
,
-
1
]
WTYPE_ZEROPOINTS
=
[
@
dataclass
class
TypeConfig
:
act_type
:
torch
.
dtype
weight_type
:
ScalarType
output_type
:
Optional
[
torch
.
dtype
]
group_scale_type
:
Optional
[
torch
.
dtype
]
group_zero_type
:
Optional
[
torch
.
dtype
]
channel_scale_type
:
Optional
[
torch
.
dtype
]
token_scale_type
:
Optional
[
torch
.
dtype
]
@
dataclass
class
Tensors
:
w_ref
:
torch
.
Tensor
a_ref
:
torch
.
Tensor
a
:
torch
.
Tensor
w_q
:
torch
.
Tensor
w_g_s
:
Optional
[
torch
.
Tensor
]
w_g_zp
:
Optional
[
torch
.
Tensor
]
w_ch_s
:
Optional
[
torch
.
Tensor
]
w_tok_s
:
Optional
[
torch
.
Tensor
]
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple
=
Tuple
[
List
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
bool
]
TEST_TYPES
=
[
# GPTQ style
# GPTQ style
(
scalar_types
.
uint4b8
,
False
),
*
(
TypeConfig
(
act_type
=
a_type
,
(
scalar_types
.
uint8b128
,
False
),
weight_type
=
w_type
,
output_type
=
None
,
group_scale_type
=
a_type
,
group_zero_type
=
None
,
channel_scale_type
=
None
,
token_scale_type
=
None
)
for
w_type
in
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
for
a_type
in
[
torch
.
float16
,
torch
.
bfloat16
]),
# AWQ style
# AWQ style
(
scalar_types
.
uint4
,
True
),
*
(
TypeConfig
(
act_type
=
a_type
,
(
scalar_types
.
uint8
,
True
),
weight_type
=
w_type
,
output_type
=
None
,
group_scale_type
=
a_type
,
group_zero_type
=
a_type
,
channel_scale_type
=
None
,
token_scale_type
=
None
)
for
w_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
for
a_type
in
[
torch
.
float16
,
torch
.
bfloat16
]),
# QQQ style
*
(
TypeConfig
(
act_type
=
torch
.
int8
,
weight_type
=
scalar_types
.
uint4b8
,
output_type
=
torch
.
float16
,
group_scale_type
=
group_scale_type
,
group_zero_type
=
None
,
channel_scale_type
=
torch
.
float
,
token_scale_type
=
torch
.
float
)
for
group_scale_type
in
[
None
,
torch
.
float16
]),
*
(
TypeConfig
(
act_type
=
torch
.
float8_e4m3fn
,
weight_type
=
scalar_types
.
uint4b8
,
output_type
=
torch
.
float16
,
group_scale_type
=
group_scale_type
,
group_zero_type
=
None
,
channel_scale_type
=
torch
.
float
,
token_scale_type
=
torch
.
float
)
for
group_scale_type
in
[
None
,
torch
.
float16
]),
]
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
...
@@ -54,116 +124,136 @@ WTYPE_ZEROPOINTS = [
...
@@ -54,116 +124,136 @@ WTYPE_ZEROPOINTS = [
IS_SUPPORTED_BY_GPU
=
current_platform
.
has_device_capability
(
90
)
IS_SUPPORTED_BY_GPU
=
current_platform
.
has_device_capability
(
90
)
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
,
offset
=
0
):
return
10
*
(
torch
.
rand
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.3
)
if
dtype
.
is_floating_point
:
return
(
scale
*
torch
.
rand
(
shape
,
device
=
"cuda"
)
-
offset
).
to
(
dtype
)
else
:
return
torch
.
randint
(
-
8
,
7
,
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
def
maybe_convert_zeropoints
(
zps
:
Optional
[
torch
.
Tensor
],
s
:
torch
.
Tensor
):
def
maybe_convert_zeropoints
(
zps
:
Optional
[
torch
.
Tensor
],
s
:
torch
.
Tensor
):
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
def
machete_quantize_and_pack
(
w
:
torch
.
Tensor
,
def
group_size_valid
(
shape
:
Tuple
[
int
,
int
,
int
],
group_size
:
Optional
[
int
])
->
bool
:
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
def
machete_quantize_and_pack
(
atype
:
torch
.
dtype
,
w
:
torch
.
Tensor
,
wtype
:
ScalarType
,
wtype
:
ScalarType
,
group_size
:
int
,
stype
:
Optional
[
torch
.
dtype
],
group_size
:
Optional
[
int
],
zero_points
:
bool
=
False
):
zero_points
:
bool
=
False
):
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
w_ref
,
w_q
,
w_s
,
w_zp
=
quantize_weights
(
w_ref
,
w_q
,
w_s
,
w_zp
=
quantize_weights
(
w
,
w
,
wtype
,
wtype
,
group_size
,
group_size
=
group_size
,
zero_points
=
zero_points
,
zero_points
=
zero_points
,
# to match how the kernel applies zps
# to match how the kernel applies zps
ref_zero_points_after_scales
=
True
)
ref_zero_points_after_scales
=
True
)
w_q
=
pack_rows
(
w_q
,
wtype
.
size_bits
,
*
w_q
.
shape
)
w_q
=
pack_rows
(
w_q
,
wtype
.
size_bits
,
*
w_q
.
shape
)
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
.
id
))
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
atype
,
wtype
,
stype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
atype
,
wtype
.
id
,
stype
))
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
def
machete_gemm_test_helper
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
def
create_test_tensors
(
shape
:
Tuple
[
int
,
int
,
int
],
wtype
:
ScalarType
,
group_size
:
int
,
types
:
TypeConfig
,
zero_points
:
bool
):
group_size
:
Optional
[
int
],
w_ref
,
w_q_packed
,
w_s
,
w_zp
=
machete_quantize_and_pack
(
subset_stride_factor
:
Optional
[
int
]
=
None
)
->
Tensors
:
b
,
wtype
,
group_size
,
zero_points
)
m
,
n
,
k
=
shape
factor
=
subset_stride_factor
or
1
output_ref
=
torch
.
matmul
(
a
,
w_ref
)
output
=
ops
.
machete_gemm
(
a
=
a
,
b_q
=
w_q_packed
,
b_type
=
wtype
,
b_scales
=
w_s
,
b_zeros
=
maybe_convert_zeropoints
(
w_zp
,
w_s
),
b_group_size
=
group_size
,
)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol
=
1
if
zero_points
else
min
(
5e-2
*
math
.
sqrt
(
a
.
shape
[
1
]),
1
)
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-1
,
atol
=
atol
)
print
(
"create_test_tensors, shape:"
,
shape
,
"types:"
,
types
,
"group_size:"
,
group_size
)
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
a
=
rand_data
((
m
*
factor
,
k
*
factor
),
types
.
act_type
,
scale
=
3
,
offset
=
2
)
reason
=
"Machete is not supported on this GPU type."
)
w
=
rand_data
((
k
*
factor
,
n
*
factor
),
types
.
act_type
,
scale
=
3
,
offset
=
1
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
MNK_SHAPES
,
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"atype"
,
ACT_TYPES
,
ids
=
lambda
x
:
str
(
x
))
@
pytest
.
mark
.
parametrize
(
"wtype_zeropoints"
,
WTYPE_ZEROPOINTS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
128
,
None
])
def
test_machete_all_schedules
(
shape
,
atype
:
torch
.
dtype
,
wtype_zeropoints
:
Tuple
[
ScalarType
,
bool
],
group_size
:
Optional
[
int
]):
m
,
n
,
k
=
shape
wtype
,
zero_points
=
wtype_zeropoints
if
group_size
is
not
None
and
k
%
group_size
!=
0
:
if
factor
>
1
:
return
a
=
a
[
0
:
m
,
0
:
k
]
w
=
w
[
0
:
k
,
0
:
n
]
print
(
f
"MNK =
{
m
}
{
n
}
{
k
}
"
)
if
types
.
group_scale_type
is
not
None
:
w
=
w
.
to
(
types
.
group_scale_type
)
if
w
.
dtype
.
itemsize
==
1
:
w
=
w
.
to
(
torch
.
float16
)
# Normalize group_size
w_ref
,
w_q_packed
,
w_s
,
w_zp
=
machete_quantize_and_pack
(
if
group_size
is
None
:
a
.
dtype
,
w
,
types
.
weight_type
,
types
.
group_scale_type
,
group_size
,
group_size
=
k
types
.
group_zero_type
is
not
None
)
assert
group_size
<=
k
a
=
rand_data
((
m
,
k
),
atype
)
if
not
a
.
dtype
.
is_floating_point
:
w
=
rand_data
((
k
,
n
),
atype
)
aiinfo
=
torch
.
iinfo
(
a
.
dtype
)
w_ref
=
w_ref
.
round
().
clamp
(
aiinfo
.
min
,
aiinfo
.
max
)
w
_ref
,
w_q_machete
,
w_s
,
w_zp
=
machete_quantize_and_pack
(
a
_ref
=
a
.
to
(
torch
.
float32
)
w
,
wtype
,
group_size
,
zero_points
)
w
_ref
=
w_ref
.
to
(
torch
.
float32
)
output_ref
=
torch
.
matmul
(
a
,
w_ref
)
w_ch_s
=
None
if
types
.
channel_scale_type
is
None
else
\
rand_data
((
n
,),
types
.
channel_scale_type
)
w_tok_s
=
None
if
types
.
token_scale_type
is
None
else
\
rand_data
((
m
,),
types
.
token_scale_type
)
for
schedule
in
ops
.
machete_supported_schedules
(
wtype
):
return
Tensors
(
w_ref
=
w_ref
,
print
(
f
"Testing schedule
{
schedule
}
"
)
a_ref
=
a_ref
,
output
=
ops
.
machete_gemm
(
a
=
a
,
a
,
w_q
=
w_q_packed
,
b_q
=
w_q_machete
,
w_g_s
=
w_s
,
b_type
=
wtype
,
w_g_zp
=
maybe_convert_zeropoints
(
w_zp
,
w_s
),
b_scales
=
w_s
,
w_ch_s
=
w_ch_s
,
b_zeros
=
maybe_convert_zeropoints
(
w_zp
,
w_s
),
w_tok_s
=
w_tok_s
)
# None stype means scales use the same dtype as a
def
machete_mm_test_helper
(
types
:
TypeConfig
,
tensors
:
Tensors
,
group_size
:
Optional
[
int
]
=
None
,
schedule
:
Optional
[
str
]
=
None
):
output_ref
=
torch
.
matmul
(
tensors
.
a_ref
,
tensors
.
w_ref
)
output_ref_type
=
output_ref
.
dtype
if
tensors
.
w_ch_s
is
not
None
:
output_ref
=
(
output_ref
.
to
(
tensors
.
w_ch_s
.
dtype
)
*
tensors
.
w_ch_s
.
unsqueeze
(
0
)).
to
(
output_ref_type
)
if
tensors
.
w_tok_s
is
not
None
:
output_ref
=
(
output_ref
.
to
(
tensors
.
w_tok_s
.
dtype
)
*
tensors
.
w_tok_s
.
unsqueeze
(
1
)).
to
(
output_ref_type
)
output
=
ops
.
machete_mm
(
a
=
tensors
.
a
,
b_q
=
tensors
.
w_q
,
b_type
=
types
.
weight_type
,
b_group_scales
=
tensors
.
w_g_s
,
b_group_zeros
=
tensors
.
w_g_zp
,
b_group_size
=
group_size
,
b_group_size
=
group_size
,
b_channel_scales
=
tensors
.
w_ch_s
,
a_token_scales
=
tensors
.
w_tok_s
,
out_type
=
types
.
output_type
,
schedule
=
schedule
,
schedule
=
schedule
,
)
)
opcheck
(
print
(
output
)
torch
.
ops
.
_C
.
machete_gemm
,
print
(
output_ref
)
(
a
,
w_q_machete
,
wtype
.
id
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
# zeropoints (after scales) causes noise around 0
atol
=
1
if
zero_points
else
min
(
5e-2
*
math
.
sqrt
(
k
),
1
)
atol
=
1
if
tensors
.
w_g_zp
is
not
None
\
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-1
,
atol
=
atol
),
\
else
min
(
5e-2
*
math
.
sqrt
(
tensors
.
a
.
shape
[
1
]),
1
)
f
"Schedule failed
{
schedule
}
"
rtol
=
1e-1
if
tensors
.
a
.
element_size
()
>=
2
else
2e-1
torch
.
testing
.
assert_close
(
output
,
output_ref
.
to
(
output
.
dtype
),
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
...
@@ -171,27 +261,50 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
...
@@ -171,27 +261,50 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
@
pytest
.
mark
.
parametrize
(
"shape"
,
@
pytest
.
mark
.
parametrize
(
"shape"
,
MNK_SHAPES
,
MNK_SHAPES
,
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"atype"
,
ACT_TYPES
,
ids
=
lambda
x
:
str
(
x
))
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
@
pytest
.
mark
.
parametrize
(
"wtype_zeropoints"
,
WTYPE_ZEROPOINTS
)
def
test_machete_all_schedules
(
shape
,
types
:
TypeConfig
):
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
128
,
None
])
def
test_machete_heuristic
(
shape
,
atype
:
torch
.
dtype
,
group_sizes
:
List
[
Optional
[
int
]]
=
[]
wtype_zeropoints
:
Tuple
[
ScalarType
,
bool
],
if
types
.
group_scale_type
is
None
:
group_size
:
Optional
[
int
]):
group_sizes
=
[
None
]
m
,
n
,
k
=
shape
else
:
wtype
,
zero_points
=
wtype_zeropoints
group_sizes
=
GROUP_SIZES_TO_TEST
for
group_size
in
group_sizes
:
if
not
group_size_valid
(
shape
,
group_size
):
continue
tensors
=
create_test_tensors
(
shape
,
types
,
group_size
)
print
(
f
"MNK =
{
shape
}
"
)
for
schedule
in
ops
.
machete_supported_schedules
(
types
.
act_type
,
types
.
weight_type
,
group_scales_type
=
types
.
group_scale_type
,
group_zeros_type
=
types
.
group_scale_type
,
out_type
=
types
.
output_type
):
print
(
f
"Testing schedule
{
schedule
}
"
)
machete_mm_test_helper
(
types
,
tensors
,
group_size
,
schedule
)
if
group_size
is
not
None
and
k
%
group_size
!=
0
:
return
# Normalize group_size
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
if
group_size
is
None
:
reason
=
"Machete is not supported on this GPU type."
)
group_size
=
k
@
pytest
.
mark
.
parametrize
(
"shape"
,
assert
group_size
<=
k
MNK_SHAPES
,
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_heuristic
(
shape
,
types
:
TypeConfig
):
group_sizes
:
List
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
else
:
group_sizes
=
GROUP_SIZES_TO_TEST
a
=
rand_data
((
m
,
k
),
atype
)
for
group_size
in
group_sizes
:
b
=
rand_data
((
k
,
n
),
atype
)
if
not
group_size_valid
(
shape
,
group_size
):
continue
machete_gemm_test_helper
(
a
,
b
,
wtype
,
group_size
,
zero_points
)
tensors
=
create_test_tensors
(
shape
,
types
,
group_size
)
machete_mm_test_helper
(
types
,
tensors
,
group_size
)
# Test working on other devices
# Test working on other devices
...
@@ -199,36 +312,45 @@ def test_machete_heuristic(shape, atype: torch.dtype,
...
@@ -199,36 +312,45 @@ def test_machete_heuristic(shape, atype: torch.dtype,
reason
=
"Machete is not supported on this GPU type."
)
reason
=
"Machete is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_machete_devices
(
device
:
str
):
def
test_machete_devices
(
device
:
str
):
m
,
n
,
k
=
512
,
4096
,
4096
wtype
=
scalar_types
.
uint4b8
group_size
=
128
group_size
=
128
zero_points
=
False
print
(
f
"MNK =
{
m
}
{
n
}
{
k
}
, device =
{
device
}
"
)
type_config
=
TypeConfig
(
act_type
=
torch
.
float16
,
weight_type
=
scalar_types
.
uint4b8
,
output_type
=
None
,
group_scale_type
=
torch
.
float16
,
group_zero_type
=
None
,
channel_scale_type
=
None
,
token_scale_type
=
None
)
a
=
rand_data
((
m
,
k
),
torch
.
float16
).
to
(
device
)
tensors
=
create_test_tensors
((
512
,
4096
,
4096
),
type_config
,
group_size
)
b
=
rand_data
((
k
,
n
),
torch
.
float16
).
to
(
device
)
machete_gemm_test_helper
(
a
,
b
,
wtype
,
group_size
,
zero_points
)
for
field
in
fields
(
Tensors
):
tensor
=
getattr
(
tensors
,
field
.
name
)
if
isinstance
(
tensor
,
torch
.
Tensor
):
setattr
(
tensors
,
field
.
name
,
tensor
.
to
(
device
))
machete_mm_test_helper
(
type_config
,
tensors
,
group_size
)
# Test working with a subset of A and B
# Test working with a subset of A and B
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
reason
=
"Machete is not supported on this GPU type."
)
reason
=
"Machete is not supported on this GPU type."
)
def
test_machete_subset
():
def
test_machete_subset
():
big_m
,
big_n
,
big_k
=
1024
,
1024
,
1024
m
,
n
,
k
=
512
,
512
,
512
wtype
=
scalar_types
.
uint4b8
group_size
=
128
group_size
=
128
zero_points
=
False
whole_a
=
rand_data
((
big_m
,
big_k
),
torch
.
float16
)
type_config
=
TypeConfig
(
act_type
=
torch
.
float16
,
whole_b
=
rand_data
((
big_k
,
big_n
),
torch
.
float16
)
weight_type
=
scalar_types
.
uint4b8
,
output_type
=
None
,
group_scale_type
=
torch
.
float16
,
group_zero_type
=
None
,
channel_scale_type
=
None
,
token_scale_type
=
None
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
tensors
=
create_test_tensors
((
512
,
4096
,
4096
),
b
=
whole_b
[
0
:
k
,
0
:
n
]
type_config
,
group_size
,
machete_gemm_test_helper
(
a
,
b
,
wtype
,
group_size
,
zero_points
)
subset_stride_factor
=
2
)
machete_mm_test_helper
(
type_config
,
tensors
,
group_size
)
# Test to make sure cuda graphs work
# Test to make sure cuda graphs work
...
@@ -239,7 +361,7 @@ class MacheteLayer(torch.nn.Module):
...
@@ -239,7 +361,7 @@ class MacheteLayer(torch.nn.Module):
self
.
kwargs
=
kwargs
self
.
kwargs
=
kwargs
def
forward
(
self
,
a
):
def
forward
(
self
,
a
):
return
ops
.
machete_
ge
mm
(
**
self
.
kwargs
)
return
ops
.
machete_mm
(
a
=
a
,
**
self
.
kwargs
)
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
...
@@ -250,19 +372,19 @@ def test_machete_cuda_graph():
...
@@ -250,19 +372,19 @@ def test_machete_cuda_graph():
a
=
rand_data
((
m
,
k
),
torch
.
float16
)
a
=
rand_data
((
m
,
k
),
torch
.
float16
)
b
=
rand_data
((
k
,
n
),
torch
.
float16
)
b
=
rand_data
((
k
,
n
),
torch
.
float16
)
wtype
=
scalar_types
.
uint4b8
wtype
=
scalar_types
.
uint4b8
stype
=
torch
.
float16
group_size
=
128
group_size
=
128
zero_points
=
False
zero_points
=
False
w_ref
,
w_q_packed
,
w_s
,
w_zp
=
machete_quantize_and_pack
(
w_ref
,
w_q_packed
,
w_s
,
w_zp
=
machete_quantize_and_pack
(
b
,
w
type
,
group_size
,
zero_points
)
a
.
dtype
,
b
,
wtype
,
s
type
,
group_size
,
zero_points
)
# Construct a trivial model with a single layer that calls a machete kernel
# Construct a trivial model with a single layer that calls a machete kernel
model
=
MacheteLayer
(
model
=
MacheteLayer
(
a
=
a
,
b_q
=
w_q_packed
,
b_q
=
w_q_packed
,
b_type
=
wtype
,
b_type
=
wtype
,
b_scales
=
w_s
,
b_
group_
scales
=
w_s
,
b_zeros
=
maybe_convert_zeropoints
(
w_zp
,
w_s
),
b_
group_
zeros
=
maybe_convert_zeropoints
(
w_zp
,
w_s
),
b_group_size
=
group_size
,
b_group_size
=
group_size
,
)
)
...
...
vllm/_custom_ops.py
View file @
96d999fb
...
@@ -444,18 +444,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -444,18 +444,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k
:
torch
.
SymInt
)
->
torch
.
Tensor
:
size_k
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
@
register_fake
(
"_C::machete_
ge
mm"
)
@
register_fake
(
"_C::machete_mm"
)
def
machete_
ge
mm_fake
(
def
machete_mm_fake
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
# Should be the tensor returned by machete_prepack_B
#
b_q
Should be the tensor returned by machete_prepack_B
b_q
:
torch
.
Tensor
,
b_q
:
torch
.
Tensor
,
b_type
:
ScalarType
,
b_type
:
ScalarType
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
out_type
:
Optional
[
torch
.
dtype
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_size
:
Optional
[
int
]
=
None
,
b_group_size
:
Optional
[
int
]
=
None
,
c
:
Optional
[
torch
.
Tensor
]
=
None
,
b_channel_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
a_token_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
schedule
:
Optional
[
str
]
=
None
,
schedule
:
Optional
[
str
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
m
=
a
.
size
(
0
)
m
=
a
.
size
(
0
)
...
@@ -463,8 +463,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -463,8 +463,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return
torch
.
empty
((
m
,
n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
return
torch
.
empty
((
m
,
n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::machete_prepack_B"
)
@
register_fake
(
"_C::machete_prepack_B"
)
def
machete_prepack_B_fake
(
b_q_weight
:
torch
.
Tensor
,
def
machete_prepack_B_fake
(
b_type
:
ScalarType
)
->
torch
.
Tensor
:
b_q_weight
:
torch
.
Tensor
,
a_type
:
torch
.
dtype
,
b_type
:
ScalarType
,
group_scales_type
:
Optional
[
torch
.
dtype
])
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b_q_weight
,
return
torch
.
empty_like
(
b_q_weight
,
memory_format
=
torch
.
contiguous_format
)
memory_format
=
torch
.
contiguous_format
)
...
@@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# machete
# machete
def
machete_supported_schedules
(
b_type
:
ScalarType
)
->
List
[
str
]:
def
machete_supported_schedules
(
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
b_type
.
id
)
a_type
:
torch
.
dtype
,
b_type
:
ScalarType
,
group_scales_type
:
Optional
[
torch
.
dtype
],
group_zeros_type
:
Optional
[
torch
.
dtype
]
=
None
,
channel_scales_type
:
Optional
[
torch
.
dtype
]
=
None
,
token_scales_type
:
Optional
[
torch
.
dtype
]
=
None
,
out_type
:
Optional
[
torch
.
dtype
]
=
None
)
->
List
[
str
]:
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
a_type
,
b_type
.
id
,
group_scales_type
,
group_zeros_type
,
channel_scales_type
,
token_scales_type
,
out_type
)
def
machete_
ge
mm
(
def
machete_mm
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b_q
:
torch
.
Tensor
,
# Should be the tensor returned by machete_prepack_B
# b_q Should be the tensor returned by machete_prepack_B
b_q
:
torch
.
Tensor
,
b_type
:
ScalarType
,
b_type
:
ScalarType
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
out_type
:
Optional
[
torch
.
dtype
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_size
:
Optional
[
int
]
=
None
,
b_group_size
:
Optional
[
int
]
=
None
,
c
:
Optional
[
torch
.
Tensor
]
=
None
,
b_channel_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
a_token_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
schedule
:
Optional
[
str
]
=
None
)
->
torch
.
Tensor
:
schedule
:
Optional
[
str
]
=
None
,
return
torch
.
ops
.
_C
.
machete_mm
(
a
,
b_q
,
b_type
.
id
,
out_type
,
b_group_scales
,
)
->
torch
.
Tensor
:
b_group_zeros
,
b_group_size
,
return
torch
.
ops
.
_C
.
machete_gemm
(
a
,
b_q
,
b_type
.
id
,
b_scales
,
b_zeros
,
b_channel_scales
,
a_token_scales
,
schedule
)
b_group_size
,
c
,
alpha
,
beta
,
schedule
)
def
machete_prepack_B
(
def
machete_prepack_B
(
b_q_weight
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
a_type
:
torch
.
dtype
,
b_type
:
ScalarType
,
b_type
:
ScalarType
)
->
torch
.
Tensor
:
group_scales_type
:
Optional
[
torch
.
dtype
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
.
id
)
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
a_type
,
b_type
.
id
,
group_scales_type
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
...
...
vllm/model_executor/layers/quantization/kernels/machete.py
View file @
96d999fb
...
@@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
c
.
weight_type
,
c
.
weight_type
,
packed_dim
=
0
)
packed_dim
=
0
)
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
self
.
config
.
weight_type
)
a_type
=
c
.
act_type
,
b_type
=
c
.
weight_type
,
group_scales_type
=
c
.
act_type
)
return
x
return
x
def
transform_w_s
(
x
):
def
transform_w_s
(
x
):
...
@@ -105,11 +107,11 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -105,11 +107,11 @@ class MacheteLinearKernel(MPLinearKernel):
if
c
.
has_g_idx
:
if
c
.
has_g_idx
:
x_2d
=
self
.
act_perm
(
x_2d
)
x_2d
=
self
.
act_perm
(
x_2d
)
output
=
ops
.
machete_
ge
mm
(
a
=
x_2d
,
output
=
ops
.
machete_mm
(
a
=
x_2d
,
b_q
=
w_q
,
b_q
=
w_q
,
b_type
=
c
.
weight_type
,
b_type
=
c
.
weight_type
,
b
_zeros
=
None
,
b_group
_zeros
=
None
,
b
_scales
=
w_s
,
b_group
_scales
=
w_s
,
b_group_size
=
c
.
group_size
)
b_group_size
=
c
.
group_size
)
if
bias
is
not
None
:
if
bias
is
not
None
:
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
96d999fb
...
@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
...
@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
def
quantize_weights
(
w
:
torch
.
Tensor
,
def
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
quant_type
:
ScalarType
,
group_size
:
int
,
group_size
:
Optional
[
int
]
,
zero_points
:
bool
=
False
,
zero_points
:
bool
=
False
,
ref_zero_points_after_scales
:
bool
=
False
):
ref_zero_points_after_scales
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
"Floating point quantization may work but has not been tested"
assert
not
zero_points
or
group_size
is
not
None
,
\
"to have group zero points, group_size must be provided "
\
"(-1 group_size is channelwise)"
orig_device
=
w
.
device
orig_device
=
w
.
device
orig_type
=
w
.
dtype
orig_type
=
w
.
dtype
...
@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
...
@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
if
group_size
==
-
1
:
if
group_size
==
-
1
:
group_size
=
size_k
group_size
=
size_k
assert
group_size
<=
size_k
# Reshape to [groupsize, -1]
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
if
group_size
is
not
None
and
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
w
=
w
.
reshape
((
group_size
,
-
1
))
...
@@ -155,6 +157,9 @@ def quantize_weights(w: torch.Tensor,
...
@@ -155,6 +157,9 @@ def quantize_weights(w: torch.Tensor,
max_q_val
=
quant_type
.
max
()
max_q_val
=
quant_type
.
max
()
min_q_val
=
quant_type
.
min
()
min_q_val
=
quant_type
.
min
()
w_s
=
torch
.
Tensor
([
1.0
]).
to
(
w
.
device
)
# unscaled case
maybe_w_zp
=
None
if
group_size
is
not
None
:
if
zero_points
:
if
zero_points
:
assert
not
quant_type
.
is_signed
()
and
quant_type
.
max
()
>
0
assert
not
quant_type
.
is_signed
()
and
quant_type
.
max
()
>
0
w_s
=
(
max_val
-
min_val
).
clamp
(
min
=
1e-5
)
/
quant_type
.
max
()
w_s
=
(
max_val
-
min_val
).
clamp
(
min
=
1e-5
)
/
quant_type
.
max
()
...
@@ -166,7 +171,6 @@ def quantize_weights(w: torch.Tensor,
...
@@ -166,7 +171,6 @@ def quantize_weights(w: torch.Tensor,
w_s
=
torch
.
max
(
w_s
=
torch
.
max
(
abs
(
max_val
/
(
max_q_val
if
max_q_val
!=
0
else
torch
.
inf
)),
abs
(
max_val
/
(
max_q_val
if
max_q_val
!=
0
else
torch
.
inf
)),
abs
(
min_val
/
(
min_q_val
if
min_q_val
!=
0
else
torch
.
inf
)))
abs
(
min_val
/
(
min_q_val
if
min_q_val
!=
0
else
torch
.
inf
)))
maybe_w_zp
=
None
# Quantize
# Quantize
w_q
=
torch
.
round
(
w
/
w_s
).
int
()
+
(
maybe_w_zp
if
zero_points
else
0
)
w_q
=
torch
.
round
(
w
/
w_s
).
int
()
+
(
maybe_w_zp
if
zero_points
else
0
)
...
@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
...
@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
# For some kernels (namely Machete) the zero-points are applied after the
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
# allows us to use tighter error tolerances in our unit tests.
if
ref_zero_points_after_scales
and
zero_points
:
if
ref_zero_points_after_scales
and
maybe_w_zp
is
not
None
:
w_ref
=
w_q
.
to
(
orig_type
)
*
w_s
-
maybe_w_zp
.
to
(
orig_type
)
*
w_s
w_ref
=
w_q
.
to
(
orig_type
)
*
w_s
-
maybe_w_zp
.
to
(
orig_type
)
*
w_s
else
:
else
:
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
...
@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
...
@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
w_q
+=
quant_type
.
bias
w_q
+=
quant_type
.
bias
# Restore original shapes
# Restore original shapes
if
group_size
<
size_k
:
if
group_size
is
not
None
and
group_size
<
size_k
:
def
reshape_w
(
w
):
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
...
@@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
...
@@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
w_q
=
reshape_w
(
w_q
)
w_q
=
reshape_w
(
w_q
)
w_ref
=
reshape_w
(
w_ref
)
w_ref
=
reshape_w
(
w_ref
)
w_s
=
w_s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
w_s
=
w_s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
if
zero_points
:
if
maybe_w_zp
is
not
None
:
maybe_w_zp
=
maybe_w_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
maybe_w_zp
=
maybe_w_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
maybe_w_zp
=
maybe_w_zp
.
to
(
device
=
orig_device
)
maybe_w_zp
=
maybe_w_zp
.
to
(
device
=
orig_device
)
return
(
return
(
w_ref
.
to
(
device
=
orig_device
),
w_ref
.
to
(
device
=
orig_device
),
w_q
.
to
(
device
=
orig_device
),
w_q
.
to
(
device
=
orig_device
),
w_s
.
to
(
device
=
orig_device
)
,
w_s
if
group_size
is
not
None
else
None
,
maybe_w_zp
,
maybe_w_zp
,
)
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment