Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7c1b7f3
Commit
e7c1b7f3
authored
Sep 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.4-dtk24.04.1'
parents
7462218e
04c62b93
Changes
442
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4188 additions
and
666 deletions
+4188
-666
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+0
-2
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+1
-0
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+0
-23
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+148
-90
csrc/quantization/cutlass_w8a8/common.hpp
csrc/quantization/cutlass_w8a8/common.hpp
+15
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+86
-297
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+340
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
+123
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
+139
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
+368
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+353
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+246
-61
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+36
-10
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+2
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+155
-28
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+1305
-0
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+3
-0
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+269
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+572
-122
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+27
-33
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
csrc/quantization/aqlm/gemm_kernels.cu
View file @
e7c1b7f3
...
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
__syncthreads
();
float
res
=
0
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
...
...
csrc/quantization/awq/dequantize.cuh
View file @
e7c1b7f3
...
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return
result
;
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// namespace awq
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
e7c1b7f3
...
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace
vllm
{
namespace
awq
{
// Pack two half values.
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
...
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
scaling_factors_shared
[
N
];
__shared__
half
zeros_shared
[
N
];
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
blockIdx_x
=
0
;
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
...
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
...
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
...
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
static
constexpr
uint32_t
ZERO
=
0x0
;
half
B_shared
[
32
*
(
128
+
8
)];
half
*
B_shared_ptr2
=
B_shared
;
half
B_shared_warp
[
32
];
int
OC
=
512
;
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
View file @
e7c1b7f3
...
...
@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
...
...
@@ -73,14 +71,12 @@ template<
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
static_assert
(
Stages
==
0
,
"Row broadcast doesn't support smem usage"
);
static_assert
(
is_static_v
<
decltype
(
take
<
0
,
2
>
(
StrideMNL
{}))
>
);
// batch stride can be dynamic or static
static_assert
(
take
<
0
,
2
>
(
StrideMNL
{})
==
Stride
<
_0
,
_1
>
{});
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
struct
SharedStorage
{
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
>
smem
;
};
// This struct has been modified to have a bool indicating that ptr_row is a
...
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
,
smem
_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem
_row
.
data
()))
{
}
:
params
(
params
)
,
smem
(
const_cast
<
Element
*>
(
shared_storage
.
smem
.
data
()))
{
}
Params
params
;
Element
*
smem
_row
;
Element
*
smem
=
nullptr
;
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
tru
e
;
return
fals
e
;
}
CUTLASS_DEVICE
bool
...
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
template
<
int
EpiTiles
,
class
GTensor
,
class
STensor
>
struct
ProducerLoadCallbacks
:
EmptyProducerLoadCallbacks
{
CUTLASS_DEVICE
ProducerLoadCallbacks
(
GTensor
&&
gRow
,
STensor
&&
sRow
,
Params
const
&
params
)
:
gRow
(
cute
::
forward
<
GTensor
>
(
gRow
)),
sRow
(
cute
::
forward
<
STensor
>
(
sRow
)),
params
(
params
)
{}
GTensor
gRow
;
// (CTA_M,CTA_N)
STensor
sRow
;
// (CTA_M,CTA_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
(
uint64_t
*
full_mbarrier_ptr
,
int
load_iteration
,
bool
issue_tma_load
)
{
if
(
params
.
ptr_row
==
nullptr
)
{
return
;
}
if
(
issue_tma_load
)
{
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr
uint32_t
copy_bytes
=
size
<
1
>
(
CtaTileShapeMNK
{})
*
sizeof_bits_v
<
Element
>
/
8
;
cutlass
::
arch
::
ClusterTransactionBarrier
::
expect_transaction
(
full_mbarrier_ptr
,
copy_bytes
);
// Issue the TMA bulk copy
auto
bulk_copy
=
Copy_Atom
<
SM90_BULK_COPY_AUTO
,
Element
>
{}.
with
(
*
full_mbarrier_ptr
);
// Filter so we don't issue redundant copies over stride-0 modes
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
(
bulk_copy
,
filter
(
gRow
),
filter
(
sRow
(
_
,
_
,
bcast_pipe_index
)));
}
}
};
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
return
EmptyProducerLoadCallbacks
{};
}
template
<
int
EpiTiles
,
class
R
Tensor
,
class
STensor
>
template
<
class
GS_GTensor
,
class
GS_STensor
,
class
GS_CTensor
,
class
Tiled_G2S
,
class
SR_S
Tensor
,
class
S
R_R
Tensor
,
class
CTensor
,
class
ThrResidue
,
class
ThrNum
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
params
(
params
)
{}
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
ConsumerStoreCallbacks
(
GS_GTensor
tGS_gRow_
,
GS_STensor
tGS_sRow_
,
GS_CTensor
tGS_cRow_
,
Tiled_G2S
tiled_g2s_
,
SR_STensor
tSR_sRow_
,
SR_RTensor
tSR_rRow_
,
CTensor
tCcRow_
,
ThrResidue
residue_tCcRow_
,
ThrNum
thr_num_
,
Params
const
&
params_
)
:
tGS_gRow
(
tGS_gRow_
)
,
tGS_sRow
(
tGS_sRow_
)
,
tGS_cRow
(
tGS_cRow_
)
,
tiled_G2S
(
tiled_g2s_
)
,
tSR_sRow
(
tSR_sRow_
)
,
tSR_rRow
(
tSR_rRow_
)
,
tCcRow
(
tCcRow_
)
,
residue_tCcRow
(
residue_tCcRow_
)
,
params
(
params_
)
{}
GS_GTensor
tGS_gRow
;
// (CPY,CPY_M,CPY_N)
GS_STensor
tGS_sRow
;
// (CPY,CPY_M,CPY_N)
GS_CTensor
tGS_cRow
;
// (CPY,CPY_M,CPY_N)
Tiled_G2S
tiled_G2S
;
SR_STensor
tSR_sRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor
tSR_rRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor
tCcRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue
residue_tCcRow
;
// (m, n)
ThrNum
thr_num
;
Params
const
&
params
;
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
begin
(
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
t
C
rRow
,
*
(
params
.
ptr_row
));
fill
(
t
SR_
rRow
,
*
(
params
.
ptr_row
));
return
;
}
auto
synchronize
=
[
&
]
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
thr_num
,
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
);
};
Tensor
tGS_gRow_flt
=
filter_zeros
(
tGS_gRow
);
Tensor
tGS_sRow_flt
=
filter_zeros
(
tGS_sRow
);
Tensor
tGS_cRow_flt
=
make_tensor
(
tGS_cRow
.
data
(),
make_layout
(
tGS_gRow_flt
.
shape
(),
tGS_cRow
.
stride
()));
for
(
int
i
=
0
;
i
<
size
(
tGS_gRow_flt
);
++
i
)
{
if
(
get
<
1
>
(
tGS_cRow_flt
(
i
))
>=
size
<
1
>
(
CtaTileShapeMNK
{}))
{
continue
;
// OOB of SMEM,
}
if
(
elem_less
(
tGS_cRow_flt
(
i
),
make_coord
(
get
<
0
>
(
residue_tCcRow
),
get
<
1
>
(
residue_tCcRow
))))
{
tGS_sRow_flt
(
i
)
=
tGS_gRow_flt
(
i
);
}
else
{
tGS_sRow_flt
(
i
)
=
Element
(
0
);
// Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize
();
}
CUTLASS_DEVICE
void
begin_loop
(
int
epi_m
,
int
epi_n
)
{
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
)
);
if
(
!
params
.
row_broadcast
)
return
;
// Do not issue LDS when row is scalar
Tensor
tSR_sRow_flt
=
filter_zeros
(
tSR_sRow
(
_
,
_
,
_
,
epi_m
,
epi_n
));
Tensor
tSR_rRow_flt
=
filter_zeros
(
tSR_rRow
)
;
copy
(
tSR_sRow_flt
,
tSR_rRow_flt
);
}
}
...
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
t
C
rRow
(
epi_v
*
FragmentSize
+
i
);
frg_row
[
i
]
=
t
SR_
rRow
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_row
;
...
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
using
ThreadCount
=
decltype
(
size
(
args
.
tiled_copy
));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
(
_
,
_
,
l
),
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
));
// (CTA_M, CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem
),
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})),
make_shape
(
_0
{},
_1
{}));
// (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto
tiled_g2s
=
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
Layout
<
Shape
<
_1
,
ThreadCount
>
,
Stride
<
_0
,
_1
>>
{},
Layout
<
_1
>
{});
auto
thr_g2s
=
tiled_g2s
.
get_slice
(
args
.
thread_idx
);
Tensor
tGS_gRow
=
thr_g2s
.
partition_S
(
gRow
);
Tensor
tGS_sRow
=
thr_g2s
.
partition_D
(
sRow
);
//// G2S: Coord
auto
cRow
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tGS_cRow
=
thr_g2s
.
partition_S
(
cRow
);
//// S2R: Smem to Reg
Tensor
tSR_sRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tSR_rRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tSR_sRow
));
// (CPY,CPY_M,CPY_N)
return
ConsumerStoreCallbacks
<
decltype
(
tGS_gRow
),
decltype
(
tGS_sRow
),
decltype
(
tGS_cRow
),
decltype
(
tiled_g2s
),
decltype
(
tSR_sRow
),
decltype
(
tSR_rRow
),
decltype
(
args
.
tCcD
),
decltype
(
args
.
residue_cD
),
ThreadCount
>
(
tGS_gRow
,
tGS_sRow
,
tGS_cRow
,
tiled_g2s
,
tSR_sRow
,
tSR_rRow
,
args
.
tCcD
,
args
.
residue_cD
,
ThreadCount
{},
params
);
}
};
...
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
>
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
params
(
params
)
{}
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
int
m
;
CUTLASS_DEVICE
void
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
...
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_
aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
copy_
if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
...
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
params
);
}
};
...
...
csrc/quantization/cutlass_w8a8/common.hpp
View file @
e7c1b7f3
#pragma once
#include "cutlass/cutlass.h"
#include <climits>
/**
* Helper function for checking CUTLASS errors
...
...
@@ -10,3 +11,17 @@
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
inline
int
get_cuda_max_shared_memory_per_block_opt_in
(
int
const
device
)
{
int
max_shared_mem_per_block_opt_in
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem_per_block_opt_in
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
return
max_shared_mem_per_block_opt_in
;
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
e7c1b7f3
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
{
private:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm75_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
cutlass
::
arch
::
OpMultiplyAdd
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
get
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm80_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
void
cutlass_scaled_mm_sm8
9
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm8
0
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
vllm
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
}
};
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
,
typename
FP8MathOperator
=
cutlass
::
arch
::
OpMultiplyAdd
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
FP8MathOperator
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
FallbackGemm
,
typename
...
EpilogueArgs
>
inline
void
fallback_cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static
const
int
max_shared_mem_per_block_opt_in
=
get_cuda_max_shared_memory_per_block_opt_in
(
0
);
size_t
const
gemm_shared_mem_size
=
sizeof
(
typename
Gemm
::
KernelType
::
SharedStorage
);
size_t
const
fallback_gemm_shared_mem_size
=
sizeof
(
typename
FallbackGemm
::
KernelType
::
SharedStorage
);
if
(
gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
)
{
return
cutlass_gemm_caller
<
Gemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
TORCH_CHECK
(
fallback_gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
);
return
cutlass_gemm_caller
<
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_default
{
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M256
{
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M64
{
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M32
{
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm75_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm75_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM256
=
typename
sm75_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128
=
Cutlass2xGemmDefault
;
using
Cutlass2xGemmM64
=
typename
sm75_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm75_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using
FallbackGemm
=
Cutlass2xGemmDefault
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// M in [1, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM256
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm80_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128BigN
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128SmallN
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM64
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM16
=
typename
sm80_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using
FallbackGemm
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM16
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
small_n
=
n
<
8192
;
if
(
small_n
)
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128SmallN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128BigN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
// M in (128, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_fp8_fallback_gemm
{
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
;
};
struct
sm89_fp8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M128
{
// M in (64, 128]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8196
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
static
const
int32_t
MainLoopStages
=
5
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
24576
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_fp8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_fp8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_fp8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_fp8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_fp8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_fp8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_int8_fallback_gemm
{
// Shared mem requirement : 61440
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
static
int32_t
const
MainLoopStages
=
5
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
struct
sm89_int8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M128
{
// M in (64, 128]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_int8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_int8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_int8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_int8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_int8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_int8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
e7c1b7f3
...
...
@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
...
...
@@ -44,11 +42,6 @@ using namespace cute;
namespace
{
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
...
...
@@ -64,6 +57,24 @@ struct enable_sm90_or_later : Kernel {
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
...
...
@@ -81,21 +92,13 @@ struct enable_sm90_or_later : Kernel {
per row or column.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogue
{
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleBDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -125,6 +128,50 @@ struct ScaledEpilogue {
}
};
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleA_Args
=
typename
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
ScaleB
::
Arguments
;
using
Bias_Args
=
typename
Bias
::
Arguments
;
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
Bias_Args
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
())};
return
ArgumentType
{
a_args
,
{
b_args
},
bias_args
};
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
...
@@ -194,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
...
...
@@ -225,24 +272,26 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
int32_t
M
>
struct
sm90_fp8_config
{
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
...
...
@@ -250,14 +299,14 @@ struct sm90_fp8_config {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
128
>
{
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
...
...
@@ -265,7 +314,8 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
64
>
{
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
...
...
@@ -278,6 +328,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
,
...
...
@@ -291,11 +413,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
0
>::
Cutlass3xGemm
;
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
64
>::
Cutlass3xGemm
;
typename
sm90_fp8_config
_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
128
>::
Cutlass3xGemm
;
typename
sm90_fp8_config
_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
...
...
@@ -316,49 +439,111 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_
caller
<
cutlass_3x_gemm
<
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_gemm_
sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_3x_gemm
<
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Scaled
Epilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Scaled
Epilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBias
>
(
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
e7c1b7f3
...
...
@@ -6,28 +6,49 @@
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
89
)
{
return
CUDA_VERSION
>=
12040
;
}
#endif
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
...
...
@@ -50,6 +71,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
if
(
version_num
>=
90
)
{
...
...
@@ -57,19 +83,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#endif
}
else
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
e7c1b7f3
...
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/fp8/common.cu
View file @
e7c1b7f3
...
...
@@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction_utils.cuh"
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
...
...
@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template
<
typename
scalar_t
>
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
scaled_fp8_conversion
(
const
scalar_t
val
,
const
float
inverted_scale
)
{
float
x
=
static_cast
<
float
>
(
val
)
*
inverted_scale
;
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
}
else
{
x
=
val
/
scale
;
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
}
...
...
@@ -40,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
int
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
...
...
@@ -87,6 +95,70 @@ typedef struct __align__(4) {
}
float8x4_t
;
template
<
typename
scalar_t
>
__device__
float
thread_max_vec
(
scalar_t
const
*
__restrict__
input
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
float
absmax_val
=
0.0
f
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
z
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
w
));
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
}
return
absmax_val
;
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
...
...
@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const
float
inverted_scale
=
1.0
f
/
(
*
scale
);
scaled_fp8_conversion_vec
<
scalar_t
,
true
>
(
out
,
input
,
inverted_scale
,
num_elems
,
tid
,
blockDim
.
x
*
gridDim
.
x
);
}
// Vectorized input/output to better utilize memory bandwidth.
const
vec4_t
<
scalar_t
>*
vectorized_in
=
reinterpret_cast
<
const
vec4_t
<
scalar_t
>*>
(
input
);
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
template
<
typename
scalar_t
>
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
int
num_vec_elems
=
num_elems
>>
2
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
#pragma unroll 4
for
(
int
i
=
tid
;
i
<
num_vec_elems
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
c10
::
Float8_e4m3fn
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
out_vec
.
x
=
scaled_fp8_conversion
(
in_vec
.
x
,
inverted_scale
);
out_vec
.
y
=
scaled_fp8_conversion
(
in_vec
.
y
,
inverted_scale
);
out_vec
.
z
=
scaled_fp8_conversion
(
in_vec
.
z
,
inverted_scale
);
out_vec
.
w
=
scaled_fp8_conversion
(
in_vec
.
w
,
inverted_scale
);
vectorized_out
[
i
]
=
out_vec
;
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
float
absmax_val
=
0.0
f
;
if
(
can_vectorize
)
{
absmax_val
=
thread_max_vec
(
token_input
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
absmax_val
=
max
(
absmax_val
,
fabs
(
x
));
}
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
out
[
i
]
=
scaled_fp8_conversion
(
input
[
i
],
inverted_scale
);
float
const
block_absmax_val_maybe
=
blockReduceMax
(
absmax_val
);
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
scale_ub
)
{
token_scale
=
min
(
block_absmax_val_maybe
,
*
scale_ub
);
}
else
{
token_scale
=
block_absmax_val_maybe
;
}
// token scale computation
token_scale
=
max
(
token_scale
/
FP8_E4M3_MAX
,
min_scaling_factor
);
scale
[
token_idx
]
=
token_scale
;
}
__syncthreads
();
// Note that we don't use inverted scales so we can match FBGemm impl.
if
(
can_vectorize
)
{
scaled_fp8_conversion_vec
<
scalar_t
,
false
>
(
token_output
,
token_input
,
token_scale
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
token_output
[
i
]
=
scaled_fp8_conversion
<
false
>
(
static_cast
<
float
>
(
token_input
[
i
]),
token_scale
);
}
}
}
}
// namespace vllm
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
...
...
@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
...
...
@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
void
dynamic_per_token_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scales
,
std
::
optional
<
at
::
Tensor
>
const
&
scale_ub
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
});
}
csrc/quantization/fp8/fp8_marlin.cu
0 → 100644
View file @
e7c1b7f3
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
fp8_marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{}
}
// namespace fp8_marlin
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
typename
scalar_t
>
__device__
inline
void
mma
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
typename
scalar_t
>
__device__
inline
void
ldsm4
(
typename
ScalarType
<
scalar_t
>::
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit
<
half
>
(
int
q
)
{
// Constants for FP8 (E4M3) and FP16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
FP16_EXPONENT
=
5
;
constexpr
int
RIGHT_SHIFT
=
FP16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
FP16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
const
half2
bias_reg
=
__float2half2_rn
(
float
(
1
<<
BIAS_OFFSET
));
// Convert to half2 and apply bias
typename
ScalarType
<
half
>::
FragB
frag_b
;
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
),
bias_reg
);
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_8bit
<
nv_bfloat16
>
(
int
q
)
{
// Constants for FP8 (E4M3) and BF16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
BF16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr
uint32_t
BIAS
=
(
BIAS_OFFSET
+
127
)
<<
23
;
const
nv_bfloat162
bias_reg
=
__float2bfloat162_rn
(
*
reinterpret_cast
<
const
float
*>
(
&
BIAS
));
// Convert to bfloat162 and apply bias
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
),
bias_reg
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
__device__
inline
void
scale
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
// Given 2 floats multiply by 2 scales (halves)
template
<
typename
scalar_t
>
__device__
inline
void
scale_float
(
float
*
c
,
typename
ScalarType
<
scalar_t
>::
FragS
&
s
)
{
scalar_t
*
s_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
constexpr
int
pack_factor
=
32
/
num_bits
;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We scale a `half2` tile in row-major layout for column-wise quantization.
int
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
thread_block_reduce
();
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
start_pipes
();
}
}
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
locks); \
}
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
)
{
int
tb_n
=
th_config
.
thread_n
;
// Get max scale groups per thread-block
// Fixed for channelwise
int
tb_groups
=
1
;
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
pipe_stages
;
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
div_ceil
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
return
false
;
}
// Verify K/N are divisible by thread K/N
if
(
prob_k
%
th_config
.
thread_k
!=
0
||
prob_n
%
th_config
.
thread_n
!=
0
)
{
return
false
;
}
// Verify min for thread K/N
if
(
th_config
.
thread_n
<
min_thread_n
||
th_config
.
thread_k
<
min_thread_k
)
{
return
false
;
}
// num_threads must be at least 128 (= 4 warps)
if
(
th_config
.
num_threads
<
128
)
{
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
8
,
"num_bits must be 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
-
1
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
int
*
locks
=
(
int
*
)
workspace
;
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
// Define kernel configurations
if
(
false
)
{
}
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
}
}
}
// namespace fp8_marlin
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
8
,
"num_bits must be 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
", size_m = "
,
size_m
);
TORCH_CHECK
(
a
.
size
(
1
)
==
size_k
,
"Shape mismatch: a.size(1) = "
,
a
.
size
(
1
),
", size_k = "
,
size_k
);
// Verify B
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
int
b_rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_rank
==
2
,
"b_scales rank = "
,
b_rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
// Channelwise only for FP8
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
1
)
num_groups
=
b_scales
.
size
(
0
);
// Verify workspace size
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
fp8_marlin
::
marlin_mm_f16i4
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
fp8_marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
}
else
{
TORCH_CHECK
(
false
,
"fp8_marlin_gemm only supports bfloat16 and float16"
);
}
return
c
;
}
#endif
csrc/quantization/fp8/nvidia/quant_utils.cuh
View file @
e7c1b7f3
...
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
// float -> fp8
...
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
0 → 100644
View file @
e7c1b7f3
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
awq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace marlin
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
>
__global__
void
awq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
constexpr
int
pack_factor
=
32
/
num_bits
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
if
(
start_k_tile
>=
k_tiles
)
{
return
;
}
int
finish_k_tile
=
min
(
start_k_tile
+
block_k_tiles
,
k_tiles
);
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
repack_stages
-
2
>
();
__syncthreads
();
};
extern
__shared__
int4
sh
[];
constexpr
int
tile_n_ints
=
tile_n_size
/
pack_factor
;
constexpr
int
stage_n_threads
=
tile_n_ints
/
4
;
constexpr
int
stage_k_threads
=
tile_k_size
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
cp_async_fence
();
return
;
}
int
first_n
=
n_tile_id
*
tile_n_size
;
int
first_n_packed
=
first_n
/
pack_factor
;
int4
*
sh_ptr
=
sh
+
stage_size
*
pipe
;
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k
+
k_id
)
*
(
size_n
/
pack_factor
)
+
first_n_packed
+
(
n_id
*
4
)])));
}
cp_async_fence
();
};
auto
repack_tile
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
return
;
}
int
warp_id
=
threadIdx
.
x
/
32
;
int
th_id
=
threadIdx
.
x
%
32
;
if
(
warp_id
>=
4
)
{
return
;
}
int
tc_col
=
th_id
/
4
;
int
tc_row
=
(
th_id
%
4
)
*
2
;
constexpr
int
tc_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
int
cur_n
=
warp_id
*
16
+
tc_col
;
int
cur_n_packed
=
cur_n
/
pack_factor
;
int
cur_n_pos
=
cur_n
%
pack_factor
;
constexpr
int
sh_stride
=
tile_n_ints
;
constexpr
uint32_t
mask
=
(
1
<<
num_bits
)
-
1
;
int4
*
sh_stage_ptr
=
sh
+
stage_size
*
pipe
;
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
// Undo interleaving
int
cur_n_pos_unpacked
;
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
undo_pack
[
8
]
=
{
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
};
cur_n_pos_unpacked
=
undo_pack
[
cur_n_pos
];
}
else
{
constexpr
int
undo_pack
[
4
]
=
{
0
,
2
,
1
,
3
};
cur_n_pos_unpacked
=
undo_pack
[
cur_n_pos
];
}
uint32_t
vals
[
8
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
packed_src_0
=
sh_stage_int_ptr
[
cur_n_packed
+
sh_stride
*
cur_elem
];
int
packed_src_1
=
sh_stage_int_ptr
[
cur_n_packed
+
(
8
/
pack_factor
)
+
sh_stride
*
cur_elem
];
vals
[
i
]
=
(
packed_src_0
>>
(
cur_n_pos_unpacked
*
num_bits
))
&
mask
;
vals
[
4
+
i
]
=
(
packed_src_1
>>
(
cur_n_pos_unpacked
*
num_bits
))
&
mask
;
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
}
else
{
constexpr
int
pack_idx
[
4
]
=
{
0
,
2
,
1
,
3
};
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
}
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
0
]
=
res1
;
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
1
]
=
res2
;
}
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
repack_tile
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
wait_for_stage
();
}
n_tile_id
+=
repack_stages
;
}
}
}
}
// namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
const
pack_factor
=
32
/
num_bits
;
// Verify B
TORCH_CHECK
(
b_q_weight
.
size
(
0
)
==
size_k
,
"b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
" is not size_k = "
,
size_k
);
TORCH_CHECK
((
size_n
/
pack_factor
)
==
b_q_weight
.
size
(
1
),
"Shape mismatch: b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
", size_n = "
,
size_n
,
", pack_factor = "
,
pack_factor
);
// Verify device and strides
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
dtype
()
==
at
::
kInt
,
"b_q_weight type is not kInt"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
b_q_weight
));
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
(
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
// Get ptrs
uint32_t
const
*
b_q_weight_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
b_q_weight
.
data_ptr
());
uint32_t
*
out_ptr
=
reinterpret_cast
<
uint32_t
*>
(
out
.
data_ptr
());
// Get dev info
int
dev
=
b_q_weight
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
false
)
{
}
CALL_IF
(
4
)
CALL_IF
(
8
)
else
{
TORCH_CHECK
(
false
,
"Unsupported repack config: num_bits = "
,
num_bits
);
}
return
out
;
}
#endif
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
e7c1b7f3
...
...
@@ -19,8 +19,9 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
...
...
@@ -32,7 +33,7 @@ inline std::string str(T x) {
return
std
::
to_string
(
x
);
}
namespace
gptq_
marlin
{
namespace
marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
...
@@ -59,23 +60,27 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
...
...
@@ -264,6 +269,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return
frag_b
;
}
// Zero-point dequantizers
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_4bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_4bit_zp
<
half
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_4bit_zp
<
nv_bfloat16
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit_zp
<
half
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64006400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_8bit_zp
<
nv_bfloat16
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388608.
f
;
fp32_intermediates
[
1
]
-=
8388608.
f
;
fp32_intermediates
[
2
]
-=
8388608.
f
;
fp32_intermediates
[
3
]
-=
8388608.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
...
...
@@ -277,6 +390,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
template
<
typename
scalar_t
>
__device__
inline
void
sub_zp
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
scalar_t2
&
frag_zp
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
zp
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_zp
)[
i
]);
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
zp
);
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
zp
);
}
// Same as above, but for act_order (each K is multiplied individually)
template
<
typename
scalar_t
>
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
...
...
@@ -404,6 +528,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
...
...
@@ -411,14 +536,18 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
...
...
@@ -437,6 +566,7 @@ __global__ void Marlin(
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
constexpr
int
pack_factor
=
32
/
num_bits
;
...
...
@@ -471,6 +601,8 @@ __global__ void Marlin(
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
...
...
@@ -478,6 +610,7 @@ __global__ void Marlin(
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
// Compute all information about the current slice which is required for
...
...
@@ -508,6 +641,7 @@ __global__ void Marlin(
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
par_id
++
;
}
};
init_slice
();
...
...
@@ -566,6 +700,13 @@ __global__ void Marlin(
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
...
@@ -605,6 +746,19 @@ __global__ void Marlin(
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
...
...
@@ -616,6 +770,18 @@ __global__ void Marlin(
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
...
...
@@ -664,14 +830,17 @@ __global__ void Marlin(
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
...
...
@@ -777,6 +946,28 @@ __global__ void Marlin(
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
...
...
@@ -784,6 +975,12 @@ __global__ void Marlin(
cp_async_fence
();
};
auto
fetch_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
...
...
@@ -932,8 +1129,82 @@ __global__ void Marlin(
}
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
if
constexpr
(
num_bits
==
4
)
{
int
zp_quant
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_shift
=
zp_quant
>>
8
;
frag_zp_0
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant
);
frag_zp_1
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant_shift
);
}
else
{
int
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
frag_zp_0
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_0
);
frag_zp_1
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_1
);
}
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -944,16 +1215,32 @@ __global__ void Marlin(
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant_shift
);
}
else
{
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_1
);
}
else
{
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
}
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
}
// Apply scale to frag_b0
...
...
@@ -967,6 +1254,11 @@ __global__ void Marlin(
}
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
...
...
@@ -1048,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
...
...
@@ -1109,6 +1401,53 @@ __global__ void Marlin(
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
int
par_offset
=
c_size
*
n_tiles
*
par_id
;
int
slice_offset
=
c_size
*
slice_col
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
par_offset
+
slice_offset
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
...
...
@@ -1189,6 +1528,12 @@ __global__ void Marlin(
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
if
constexpr
(
has_zp
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
...
...
@@ -1197,6 +1542,7 @@ __global__ void Marlin(
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
...
...
@@ -1217,6 +1563,7 @@ __global__ void Marlin(
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
...
...
@@ -1325,7 +1672,11 @@ __global__ void Marlin(
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
...
...
@@ -1354,6 +1705,7 @@ __global__ void Marlin(
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
...
...
@@ -1363,22 +1715,24 @@ __global__ void Marlin(
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER &&
group_blocks == GROUP_BLOCKS &&
\
num_threads == NUM_THREADS) {
\
has_act_order == HAS_ACT_ORDER &&
has_zp == HAS_ZP &&
\
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {
\
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>,
\
HAS_ZP,
GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
}
typedef
struct
{
...
...
@@ -1517,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
true
;
}
int
determine_reduce_max_m
(
int
prob_m
,
int
max_par
)
{
constexpr
int
tile_m_size
=
16
;
if
(
prob_m
<=
tile_m_size
)
{
return
tile_m_size
;
}
else
if
(
prob_m
<=
tile_m_size
*
2
)
{
return
tile_m_size
*
2
;
}
else
if
(
prob_m
<=
tile_m_size
*
3
)
{
return
tile_m_size
*
3
;
}
else
if
(
prob_m
<=
tile_m_size
*
4
)
{
return
tile_m_size
*
4
;
}
else
{
int
cur_par
=
min
(
div_ceil
(
prob_m
,
tile_m_size
*
4
),
max_par
);
return
tile_m_size
*
4
*
cur_par
;
}
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
...
...
@@ -1548,44 +1923,77 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
use_fp32_reduce
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
// TODO: remove alias when we start supporting other 8bit types
int
num_bits
=
q_type
.
size_bits
();
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
...
...
@@ -1664,7 +2072,9 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
...
...
@@ -1701,28 +2111,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
// Define kernel configurations
if
(
false
)
{
}
CALL_IF
(
4
,
32
,
2
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
8
,
8
,
256
)
CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
4
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
4
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
8
,
4
,
8
,
128
)
AWQ_CALL_IF
(
4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
", num_groups = "
,
num_groups
,
", group_size = "
,
group_size
,
", thread_m_blocks = "
,
thread_m_blocks
,
", thread_n_blocks = "
,
thread_n_blocks
,
", thread_k_blocks = "
,
thread_k_blocks
,
", num_bits = "
,
num_bits
);
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
...
...
@@ -1730,17 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
||
*
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
->
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
->
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
...
...
@@ -1749,16 +2175,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = "
,
size_k
);
// Verify B
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_
marlin
::
tile_size
==
0
,
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
...
...
@@ -1772,6 +2197,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_zeros
.
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
is_contiguous
(),
"b_zeros is not contiguous"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
...
...
@@ -1784,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// Alloc C tmp buffer that is going to be used for the global reduce
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
reduce_max_m
=
0
;
reduce_n
=
0
;
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
...
...
@@ -1805,8 +2244,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
int
b_
rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_
rank
==
2
,
"b_scales rank = "
,
b_
rank
,
" is not 2"
);
int
rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
2
,
"b_scales rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
0
);
...
...
@@ -1832,34 +2271,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
2
,
"b_zeros rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_zeros
.
size
(
0
)
==
num_groups
,
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
/
pack_factor
,
"b_zeros dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
// Verify workspace size
TORCH_CHECK
(
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
gptq_
marlin
::
marlin_mm
_f16i4
<
half
>
(
marlin
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_
marlin
::
max_par
);
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
gptq_
marlin
::
marlin_mm
_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BF
loat
16
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
f
loat
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
e7c1b7f3
#include "gptq_marlin.cuh"
namespace
gptq_marlin
{
static
constexpr
int
repack_stages
=
8
;
static
constexpr
int
repack_threads
=
256
;
static
constexpr
int
tile_k_size
=
tile_size
;
static
constexpr
int
tile_n_size
=
tile_k_size
*
4
;
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
...
...
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
...
...
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
}
}
}
// namespace
gptq_
marlin
#define CALL_IF(NUM_BITS, HAS_PERM)
\
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {
\
cudaFuncSetAttribute(
\
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads,
\
NUM_BITS,
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
\
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads, NUM_BITS, \
HAS_PERM>
\
<<<blocks,
gptq_
marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
\
}
// namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads,
NUM_BITS,
\
HAS_PERM>,
\
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads, NUM_BITS,
\
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(
\
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_
marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_
marlin
::
tile_n_size
);
TORCH_CHECK
(
size_k
%
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
...
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
({
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
torch
::
Tensor
out
=
torch
::
empty
(
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
// Detect if there is act_order
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment