Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5288c06a
Unverified
Commit
5288c06a
authored
Aug 20, 2024
by
Lucas Wilkinson
Committed by
GitHub
Aug 20, 2024
Browse files
[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174)
parent
b6f99a6f
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
754 additions
and
2 deletions
+754
-2
csrc/quantization/machete/machete_prepack_kernel.cuh
csrc/quantization/machete/machete_prepack_kernel.cuh
+62
-0
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+71
-0
csrc/quantization/machete/machete_prepacked_layout.cuh
csrc/quantization/machete/machete_prepacked_layout.cuh
+220
-0
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+79
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+15
-0
tests/kernels/test_machete_gemm.py
tests/kernels/test_machete_gemm.py
+272
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+26
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+9
-2
No files found.
csrc/quantization/machete/machete_prepack_kernel.cuh
0 → 100644
View file @
5288c06a
#pragma once
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace
machete
{
// Prepacking kernel: each thread block moves one `TileShapeNKL` tile of `B_in`
// into the matching block of `B_tiled_out`. The tile is picked from the block
// index (x, y, z), so the launch grid supplies one block per tile.
//
// The tiled copy arranges threads as a (4, 32) grid with two values per thread
// along the second mode, i.e. it is written for 4 * 32 = 128 threads per block.
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
          typename BTiledOutTensor>
static __global__ void prepack_B_kernel(BInTensor B_in,
                                        BTiledOutTensor B_tiled_out) {
  // Source tile owned by this thread block.
  auto src_tile = local_tile(B_in, TileShapeNKL{},
                             make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
  // Destination block: keep the full intra-block mode, pin the block coords.
  auto dst_tile = B_tiled_out(make_coord(_, _),
                              make_coord(blockIdx.x, blockIdx.y), blockIdx.z);

  using ThreadLayout = Layout<Shape<_4, _32>, Stride<_32, _1>>;
  using ValueLayout = Layout<Shape<_1, _2>>;
  auto block_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
                                    ThreadLayout{}, ValueLayout{});
  auto my_copy = block_copy.get_thread_slice(threadIdx.x);

  Tensor my_src = my_copy.partition_S(src_tile);
  Tensor my_dst = my_copy.partition_D(dst_tile);

  // Register-backed staging tensor shaped like this thread's destination
  // partition.
  auto staged = make_tensor<ElementB>(shape(my_dst));

  // GMEM -> RMEM, then RMEM -> GMEM.
  // NOTE(review): the write-back uses a uint8_t copy atom rather than
  // ElementB, presumably so sub-byte elements are stored in whole bytes --
  // confirm before changing.
  copy(block_copy, my_src, staged);
  copy(Copy_Atom<DefaultCopy, uint8_t>{}, staged, my_dst);
}
// Host-side launcher for `prepack_B_kernel`.
//
// Params:
//   stream    - CUDA stream the kernel is enqueued on.
//   B_in_ptr  - source elements, addressed through `B_layout`.
//   B_layout  - rank-3 (N, K, L) layout of the source. Every mode must be a
//               multiple of the corresponding tile extent
//               (`PPBlockShape_NK` with an L extent of 1).
//   B_out_ptr - destination buffer, addressed through
//               `PrepackedLayoutB::ilvd_NKbNbKL_to_offset`.
//
// Raises (via TORCH_CHECK) when a mode of `B_layout` is not tile-aligned.
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream,
                      typename PrepackedLayoutB::ElementB const* B_in_ptr,
                      InLayout B_layout,
                      typename PrepackedLayoutB::ElementB* B_out_ptr) {
  using TileShapeNKL =
      decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));

  // Reject shapes that are not whole multiples of the tile; the kernel has no
  // bounds handling for partial tiles.
  TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0,
              "prepack_B: N extent ", size<0>(B_layout),
              " must be a multiple of ", size<0>(TileShapeNKL{}));
  TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0,
              "prepack_B: K extent ", size<1>(B_layout),
              " must be a multiple of ", size<1>(TileShapeNKL{}));
  TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0,
              "prepack_B: L extent ", size<2>(B_layout),
              " must be a multiple of ", size<2>(TileShapeNKL{}));

  auto ilvd_NKbNbKL_to_offset =
      PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));

  // One thread block per tile.
  auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
  auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
  auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});

  auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
  auto B_tiled_out =
      make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);

  // 128 threads per block matches the (4, 32) thread layout used by the
  // kernel's tiled copy.
  prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
}
};
// namespace machete
\ No newline at end of file
csrc/quantization/machete/machete_prepack_launcher.cuh
0 → 100644
View file @
5288c06a
#pragma once
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace
machete
{
// Prepacks the quantized weight tensor `B` into the layout described by
// `PrepackedLayoutB` and returns the result as a new tensor.
//
// Params:
//   B - packed weights, expected as (packed_K, N) (see the comment below).
//       Its storage dtype may hold several `ElementB` values per item.
//
// Returns: a tensor allocated with `torch::empty_like(B)` that receives the
// prepacked data.
//
// Raises (via TORCH_CHECK) when the unpacked K extent or N is not a multiple
// of the prepacked block shape, or when the transposed layout does not have
// unit stride along packed_K.
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
  using ElementB = typename PrepackedLayoutB::ElementB;
  using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;

  auto device = B.device();
  auto stream = at::cuda::getCurrentCUDAStream(device.index());
  auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
  // elements per storage item for B
  auto eles_per_storage =
      (B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;

  // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
  // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
  auto Bt_packed = B.t();

  TORCH_CHECK(
      (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
      "B.shape[0] (in terms of unpacked elements) must be a multiple of ",
      size<1>(PPBlockShape_NK{}));
  TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
              "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));

  using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
  auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");

  // convert (N,packed_K,L) layout to (N,K,L) layout
  //  in effect we want to do: blocked_product(layout_Bt_packed,
  //      make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
  //                          Step<_1, _0, _2>{}));
  // but blocked_product does not support dynamic strides so we implement the
  // equivalent manually,
  //   new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
  //   new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
  //                 when s1 == 1
  TORCH_CHECK(stride<1>(l_Bt_packed) == 1,
              "prepack_impl: the (N, packed_K, L) layout of B must have unit "
              "stride along packed_K");
  // clang-format off
  auto const layout_Bt = make_layout(
      transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
        return idx == 1 ? ele * eles_per_storage : ele;
      }),
      transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
        return idx != 1 ? ele * eles_per_storage : ele;
      }));
  // clang-format on

  // Allocate output
  torch::Tensor D = torch::empty_like(B);

  prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
                              static_cast<ElementB*>(D.mutable_data_ptr()));

  return D;
}
// Type-parameterized entry point for prepacking B. Only the declaration of
// `dispatch` appears here; its definition is not visible in this header.
//
// `AccumulatorT` defaults to float; `ScaleT` and `ZeroT` default to
// cutlass::half_t.
template <typename ElementA, typename ElementB, typename ElementD,
          typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
          typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
  // Takes the tensor to prepack and returns a tensor.
  static torch::Tensor dispatch(torch::Tensor B);
};
};
// namespace machete
\ No newline at end of file
csrc/quantization/machete/machete_prepacked_layout.cuh
0 → 100644
View file @
5288c06a
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace
machete
{
using
namespace
cute
;
// Tag type used as the default for the `IlvBlkLayout_` template parameter of
// `PrepackedLayoutBTemplate`; when it is passed, the interleave block layout
// is chosen automatically from the element types.
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
// Describes the prepacked memory layout of B for a given combination of
// element types and kernel schedule.
//
// Template params:
//   ElementA_      - activation element type; also used as the MMA type.
//   ElementB_      - (quantized) weight element type.
//   ElementD_      - output element type.
//   AccumulatorT   - accumulator element type.
//   LayoutB        - layout tag for B, converted to a GMMA major mode.
//   KernelSchedule - selects the atom layout (cooperative schedules use two
//                    atoms along the first MMA dimension).
//   IlvBlkLayout_  - interleave block layout; `IlvBlkLayoutAuto` picks one
//                    automatically, `void` disables interleaving.
template <typename ElementA_, typename ElementB_, typename ElementD_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
struct PrepackedLayoutBTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementD = ElementD_;
  using ElementAccumulator =
      AccumulatorT;  // Element type for internal accumulation
  using ElementMma = MmaType;

  // Only use interleaved layouts for subbyte weights, prmt instructions makes
  // non-interleaved layouts for 8bit+ weights efficient enough we don't need
  // interleaved layouts
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
                         decltype(get_interleaved_blk_layout<
                                  ElementB, sizeof_bits_v<ElementA>, 32>()),
                         void>,
      IlvBlkLayout_>;

  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape, smallest layout atom for loading into registers
  //   (can contain multiple wgmma instructions worth of data in one block)
  // We ideally want this to be configured such that a thread can perform 128bit
  // loads, i.e. the amount of data associated with each thread within a
  // prepacked block is a multiple of 128bits, when using a cooperative schedule
  // we have 256 threads working a single block at a time, this means each
  // thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
  // for a 4bit type this would be 128bits
  using PPBlockShape_NK = Shape<_128, _64>;

  // Create the shape of the tile anticipated to be used by the GEMM kernel,
  //  when the kernel executes we will compute `Ct = Bt * At` since the
  //  quantized weights (B), must be the lhs operand so the flow through
  //  registers.
  // The _128 here doesn't actually impact the shape of the stored tile directly
  //  but may impact the op selected by rs_op_selector
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));

  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule,
                      KernelTmaWarpSpecializedCooperativeMixedInput>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // Prepacked block, (athrid, val) -> (N,K)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }

  // Prepacked block, (N,K) -> (athrid, val)
  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return the non-interleaved layout (values of a thread stored
    // contiguously, see the Step ordering)
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(IlvdFrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
      //   if FrgV is {A, B, C, D, E, F, G, H}
      //   then ((IlvBlk), FrgB) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};
      // FrgV is split into size(FrgV) / size(ilvdBlk) interleave blocks
      // below, so it has to divide evenly by the block size (not just by 4).
      static_assert(size(frgV) % size(ilvdBlk) == 0,
                    "FrgV must be divisible by the interleave block size");
      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));

      // Return interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }

  // Prepacked block, (N,K) -> (storage_offset)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_nkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_nkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_nkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_nkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L))
    //   => ((BlockN, BlockK), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_nkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_nkl), tile);
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }

  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_nkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_nkl);
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};
};
// namespace machete
\ No newline at end of file
csrc/quantization/machete/machete_pytorch.cu
0 → 100644
View file @
5288c06a
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
namespace
machete
{
using
namespace
vllm
;
//
// Utils (type dispatching)
//
// Maps a runtime `ScalarType` onto the matching cutlass element type and calls
// `fn` with a default-constructed value of that type, returning whatever `fn`
// returns. Types other than kU4 / kU8 / kU4B8 / kU8B128 fail a TORCH_CHECK.
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
  if (type == vllm::kU4) {
    return fn(cutlass::uint4b_t{});
  }
  if (type == vllm::kU8) {
    return fn(cutlass::uint8_t{});
  }
  if (type == vllm::kU4B8) {
    return fn(cutlass::vllm_uint4b8_t{});
  }
  if (type == vllm::kU8B128) {
    return fn(cutlass::vllm_uint8b128_t{});
  }
  // No supported type matched.
  TORCH_CHECK(false, "Unsupported type ", type.str());
}
// Dispatch cases for the compute (activation) dtypes handled by this file;
// delegates to AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES.
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)

// AT_DISPATCH_SWITCH over TYPE restricted to the cases defined above.
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
// Lists the schedule names reported by `GemmDispatcher` for weight type
// `btype`. The activation type is fixed to half_t for this query.
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
  auto const query_schedules = [](auto weight_tag) {
    using WeightElement = decltype(weight_tag);
    return GemmDispatcher<half_t, WeightElement>::supported_schedules();
  };
  return scalar_type_dispatch(*btype, query_schedules);
}
// Mixed-precision GEMM entry point exposed to PyTorch.
//
// Bundles the arguments into `PyTorchArguments`, then dispatches first on the
// weight type (`btype`) and then on the dtype of `A` to the matching
// `GemmDispatcher<...>::dispatch`, whose result is returned.
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   ScalarTypeTorchPtr const& btype,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule) {
  auto launch_args = PyTorchArguments{.A = A,
                                      .B = B,
                                      .scales = scales,
                                      .zeros = zeros,
                                      .group_size = group_size,
                                      .C = C,
                                      .alpha = alpha,
                                      .beta = beta,
                                      .schedule = schedule};

  // Outer dispatch: weight element type. Inner dispatch: activation dtype.
  auto const run_for_weight_type = [&](auto weight_tag) {
    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
        A.scalar_type(), "machete_gemm", [&] {
          using ActElement = equivalent_cutlass_type_t<scalar_t>;
          using WeightElement = decltype(weight_tag);
          return GemmDispatcher<ActElement, WeightElement>::dispatch(
              launch_args);
        });
  };

  return scalar_type_dispatch(*btype, run_for_weight_type);
}
// Prepack entry point exposed to PyTorch: dispatches on the weight type
// `btype` and forwards `B` to `PrepackBDispatcher<half_t, BType, half_t>`.
torch::Tensor prepack_B(torch::Tensor const& B,
                        ScalarTypeTorchPtr const& btype) {
  auto const prepack_for_weight_type = [&](auto weight_tag) {
    using WeightElement = decltype(weight_tag);
    return PrepackBDispatcher<half_t, WeightElement, half_t>::dispatch(B);
  };
  return scalar_type_dispatch(*btype, prepack_for_weight_type);
}
};
// namespace machete
csrc/torch_bindings.cpp
View file @
5288c06a
...
@@ -133,6 +133,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -133,6 +133,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
);
ops
.
impl
(
"gptq_marlin_24_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_24_gemm
);
ops
.
impl
(
"gptq_marlin_24_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_24_gemm
);
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops
.
def
(
"machete_supported_schedules"
,
&
machete
::
supported_schedules
);
ops
.
def
(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor"
);
ops
.
impl
(
"machete_gemm"
,
torch
::
kCUDA
,
&
machete
::
gemm
);
ops
.
def
(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor"
);
ops
.
impl
(
"machete_prepack_B"
,
torch
::
kCUDA
,
&
machete
::
prepack_B
);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
ops
.
impl
(
"gptq_marlin_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_gemm
);
ops
.
impl
(
"gptq_marlin_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_gemm
);
...
...
tests/kernels/test_machete_gemm.py
0 → 100644
View file @
5288c06a
"""Tests for the machete kernel.
Run `pytest tests/kernels/test_machete_gemm.py`.
"""
import
math
from
typing
import
Optional
,
Tuple
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
# Device strings used by the multi-device test: at most the first two visible
# CUDA devices, and none at all when no device is visible (previously
# "cuda:0"/"cuda:1" were listed even with zero devices).
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]
# (M, N, K) problem sizes exercised by the parametrized tests. Each shape is
# listed once; a repeated entry only re-runs an identical test case.
MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (257, 128, 4096),
    (257, 4224, 4160),
    (257, 4096, 4096),
    (64, 4096, 4096),
]
# Activation dtypes the tests are parametrized over.
ACT_TYPES = [torch.float16, torch.bfloat16]

# (weight type, use zero points) pairs the tests are parametrized over.
WTYPE_ZEROPOINTS = [
    # GPTQ style
    (scalar_types.uint4b8, False),
    (scalar_types.uint8b128, False),
    # AWQ style
    (scalar_types.uint4, True),
    (scalar_types.uint8, True),
]
# TODO: in a future PR, refactor this and `is_quant_method_supported` in the
#  kernel unit tests into a common utility function. Currently the use of
#  `is_quant_method_supported` conflates kernels with quantization methods,
#  an assumption which is breaking down as quantization methods can have
#  multiple kernels and some kernels support multiple quantization methods.
# True when the first component of the current platform's device capability
# is at least 9; used below to skip the tests on other GPUs.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
def rand_data(shape, dtype=torch.float16):
    """Return a random tensor of `shape` and `dtype` on the "cuda" device.

    Values are `torch.rand` samples shifted down by 0.3 and scaled by 10.
    """
    uniform = torch.rand(shape, dtype=dtype, device="cuda")
    return 10 * (uniform - 0.3)
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
    """Return `-1 * s * zps` with `zps` cast to the dtype of `s`.

    `None` zero points are passed through unchanged.
    """
    if zps is None:
        return zps
    zps_as_scale_dtype = zps.to(s.dtype)
    return -1 * s * zps_as_scale_dtype
def machete_quantize_and_pack(w: torch.Tensor,
                              wtype: ScalarType,
                              group_size: int,
                              zero_points: bool = False):
    """Quantize `w`, pack it and prepack it for the machete kernel.

    Returns `(w_ref, w_q_machete, w_s, w_zp)`: the reference weights, the
    prepacked quantized weights, and the scales and zero points as returned
    by `quantize_weights`.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    # The reference is computed with zero points applied after the scales, to
    # match how the kernel applies zps.
    reference, quantized, scales, zps = quantize_weights(
        w,
        wtype,
        group_size,
        zero_points=zero_points,
        ref_zero_points_after_scales=True)

    packed = pack_rows(quantized, wtype.size_bits, *quantized.shape)
    # Convert to col major: same logical shape, transposed memory order.
    col_major = packed.t().contiguous().t()
    prepacked = ops.machete_prepack_B(col_major, wtype)

    return reference, prepacked, scales, zps
def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
                             wtype: ScalarType, group_size: int,
                             zero_points: bool):
    """Quantize `b`, run `machete_gemm` on `a` and compare with `torch.matmul`.

    The comparison uses the dequantized reference weights, not `b` itself.
    """
    reference_w, packed_w, scales, zps = machete_quantize_and_pack(
        b, wtype, group_size, zero_points)

    expected = torch.matmul(a, reference_w)

    actual = ops.machete_gemm(
        a=a,
        b_q=packed_w,
        b_type=wtype,
        b_scales=scales,
        b_zeros=maybe_convert_zeropoints(zps, scales),
        b_group_size=group_size,
    )

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    #  zeropoints (after scales) causes noise around 0
    if zero_points:
        atol = 1
    else:
        atol = min(5e-2 * math.sqrt(a.shape[1]), 1)
    torch.testing.assert_close(actual, expected, rtol=1e-1, atol=atol)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_all_schedules(shape, atype: torch.dtype,
                               wtype_zeropoints: Tuple[ScalarType, bool],
                               group_size: Optional[int]):
    """Run every schedule reported for `wtype` and compare with matmul.

    Shapes whose K is not divisible by `group_size` are reported as skipped.
    A failing comparison names the schedule that produced it.
    """
    m, n, k = shape
    wtype, zero_points = wtype_zeropoints

    if group_size is not None and k % group_size != 0:
        # Report the unsupported combination instead of silently passing.
        pytest.skip(f"k={k} is not divisible by group_size={group_size}")

    print(f"MNK = {m} {n} {k}")

    # Normalize group_size
    if group_size is None:
        group_size = k
    assert group_size <= k

    a = rand_data((m, k), atype)
    w = rand_data((k, n), atype)

    w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
        w, wtype, group_size, zero_points)

    output_ref = torch.matmul(a, w_ref)

    for schedule in ops.machete_supported_schedules(wtype):
        output = ops.machete_gemm(
            a,
            b_q=w_q_machete,
            b_type=wtype,
            b_scales=w_s,
            b_zeros=maybe_convert_zeropoints(w_zp, w_s),
            b_group_size=group_size,
            schedule=schedule,
        )

        # Relax atol as our reduction dim becomes larger (more rounding error)
        # Relax atol when we have zeropoints since the way machete applies
        #  zeropoints (after scales) causes noise around 0
        atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
        # `msg` receives the generated mismatch report; prefix it with the
        # schedule so the failing schedule is visible in the test output.
        torch.testing.assert_close(
            output,
            output_ref,
            rtol=1e-1,
            atol=atol,
            msg=lambda detail, schedule=schedule:
            f"Schedule failed {schedule}\n{detail}")
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_heuristic(shape, atype: torch.dtype,
                           wtype_zeropoints: Tuple[ScalarType, bool],
                           group_size: Optional[int]):
    """Run `machete_gemm` without an explicit schedule and compare with matmul.

    Shapes whose K is not divisible by `group_size` are reported as skipped.
    """
    m, n, k = shape
    wtype, zero_points = wtype_zeropoints

    if group_size is not None and k % group_size != 0:
        # Report the unsupported combination instead of silently passing.
        pytest.skip(f"k={k} is not divisible by group_size={group_size}")

    # Normalize group_size
    if group_size is None:
        group_size = k
    assert group_size <= k

    a = rand_data((m, k), atype)
    b = rand_data((k, n), atype)

    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
    """Run the gemm check with both inputs moved to `device`."""
    shape = (512, 4096, 4096)
    m, n, k = shape
    weight_type = scalar_types.uint4b8
    group = 128
    use_zero_points = False

    print(f"MNK = {m} {n} {k}, device = {device}")

    activations = rand_data((m, k), torch.float16).to(device)
    weights = rand_data((k, n), torch.float16).to(device)

    machete_gemm_test_helper(activations, weights, weight_type, group,
                             use_zero_points)
# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_subset():
    """Run the gemm check on slices taken from larger A and B tensors."""
    full_m, full_n, full_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512
    weight_type = scalar_types.uint4b8
    group = 128
    use_zero_points = False

    full_a = rand_data((full_m, full_k), torch.float16)
    full_b = rand_data((full_k, full_n), torch.float16)

    # Leading sub-blocks of the larger tensors (views, not copies).
    a_view = full_a[:m, :k]
    b_view = full_b[:k, :n]

    machete_gemm_test_helper(a_view, b_view, weight_type, group,
                             use_zero_points)
# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):
    """Minimal module whose forward pass is a single `machete_gemm` call.

    All arguments for the call are captured at construction time.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Keyword arguments forwarded verbatim to `ops.machete_gemm`.
        self.kwargs = kwargs

    def forward(self, a):
        # NOTE: the `a` argument is not used; the call relies entirely on the
        # keyword arguments stored in `__init__`.
        return ops.machete_gemm(**self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_cuda_graph():
    """Capture a machete gemm in a CUDA graph, replay it, and check the output.

    The output buffer is zeroed between capture and replay, so the final
    comparison checks what the replay wrote.
    """
    m, n, k = 512, 4096, 4096

    a = rand_data((m, k), torch.float16)
    b = rand_data((k, n), torch.float16)
    wtype = scalar_types.uint4b8
    group_size = 128
    zero_points = False

    reference_w, packed_w, scales, zps = machete_quantize_and_pack(
        b, wtype, group_size, zero_points)

    # Construct a trivial model with a single layer that calls a machete kernel
    layer_kwargs = dict(
        a=a,
        b_q=packed_w,
        b_type=wtype,
        b_scales=scales,
        b_zeros=maybe_convert_zeropoints(zps, scales),
        b_group_size=group_size,
    )
    model = MacheteLayer(**layer_kwargs)

    expected = torch.matmul(a, reference_w)

    # Run the model with a cuda graph
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output = model(a)
    output.zero_()
    graph.replay()

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    #  zeropoints (after scales) causes noise around 0
    if zero_points:
        atol = 1
    else:
        atol = min(5e-2 * math.sqrt(k), 1)
    torch.testing.assert_close(output, expected, rtol=1e-1, atol=atol)
vllm/_custom_ops.py
View file @
5288c06a
...
@@ -329,6 +329,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -329,6 +329,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
num_bits
,
size_m
,
size_n
,
size_k
)
num_bits
,
size_m
,
size_n
,
size_k
)
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return the result of the `_C.machete_supported_schedules` op for
    `b_type`."""
    return torch.ops._C.machete_supported_schedules(b_type)
def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Thin wrapper around the `_C.machete_gemm` op.

    All arguments are forwarded positionally, in declaration order; optional
    arguments left as `None` are passed through as `None`.
    """
    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
                                     b_group_size, c, alpha, beta, schedule)
def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Thin wrapper around the `_C.machete_prepack_B` op; returns its result
    for `b_q_weight` and `b_type`."""
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# fp8
# fp8
def
scaled_fp8_quant
(
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
5288c06a
...
@@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
...
@@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def
quantize_weights
(
w
:
torch
.
Tensor
,
def
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
quant_type
:
ScalarType
,
group_size
:
int
,
group_size
:
int
,
zero_points
:
bool
=
False
):
zero_points
:
bool
=
False
,
ref_zero_points_after_scales
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
"Floating point quantization may work but has not been tested"
...
@@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor,
...
@@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor,
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
# Compute ref (dequantized)
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if
ref_zero_points_after_scales
and
zero_points
:
w_ref
=
w_q
.
to
(
orig_type
)
*
w_s
-
maybe_w_zp
.
to
(
orig_type
)
*
w_s
else
:
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
if
quant_type
.
has_bias
():
if
quant_type
.
has_bias
():
w_q
+=
quant_type
.
bias
w_q
+=
quant_type
.
bias
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment