Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7c1b7f3
Commit
e7c1b7f3
authored
Sep 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.4-dtk24.04.1'
parents
7462218e
04c62b93
Changes
442
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4188 additions
and
666 deletions
+4188
-666
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+0
-2
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+1
-0
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+0
-23
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+148
-90
csrc/quantization/cutlass_w8a8/common.hpp
csrc/quantization/cutlass_w8a8/common.hpp
+15
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+86
-297
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+340
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
+123
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
+139
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
+368
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+353
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+246
-61
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+36
-10
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+2
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+155
-28
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+1305
-0
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+3
-0
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+269
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+572
-122
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+27
-33
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
csrc/quantization/aqlm/gemm_kernels.cu
View file @
e7c1b7f3
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
}
__syncthreads
();
__syncthreads
();
float
res
=
0
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
while
(
iters
--
)
{
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
...
...
csrc/quantization/awq/dequantize.cuh
View file @
e7c1b7f3
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return
result
;
return
result
;
#endif
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
}
// namespace awq
}
// namespace awq
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
e7c1b7f3
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace
vllm
{
namespace
vllm
{
namespace
awq
{
namespace
awq
{
// Packs two half values into a single 32-bit word.
// The leading unsigned-short-sized bits of `x` land in the low 16 bits of the
// result and those of `y` in the high 16 bits.
static inline __device__ __host__ unsigned __pack_half2(const half x,
                                                        const half y) {
  const unsigned low_word = *reinterpret_cast<const unsigned short*>(&x);
  const unsigned high_word = *reinterpret_cast<const unsigned short*>(&y);
  return low_word | (high_word << 16);
}
template
<
int
N
>
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
scaling_factors_shared
[
N
];
__shared__
half
zeros_shared
[
N
];
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
blockIdx_x
=
0
;
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t
B_loaded
=
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if formulated as deq =
// TODO (Haotian): can save 4 assembly instructions if formulated as deq =
// q * scale - zero * scale.
// q * scale - zero * scale.
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__
void
__launch_bounds__
(
64
)
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
static
constexpr
uint32_t
ZERO
=
0x0
;
static
constexpr
uint32_t
ZERO
=
0x0
;
half
B_shared
[
32
*
(
128
+
8
)];
half
B_shared
[
32
*
(
128
+
8
)];
half
*
B_shared_ptr2
=
B_shared
;
half
*
B_shared_ptr2
=
B_shared
;
half
B_shared_warp
[
32
];
int
OC
=
512
;
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
View file @
e7c1b7f3
...
@@ -64,8 +64,6 @@ using namespace detail;
...
@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast
// Row vector broadcast
template
<
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
int
Stages
,
class
CtaTileShapeMNK
,
class
CtaTileShapeMNK
,
class
Element
,
class
Element
,
...
@@ -73,14 +71,12 @@ template<
...
@@ -73,14 +71,12 @@ template<
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
>
struct
Sm90RowOrScalarBroadcast
{
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
Stages
==
0
,
"Row broadcast doesn't support smem usage"
);
static_assert
(
static_assert
(
is_static_v
<
decltype
(
take
<
0
,
2
>
(
StrideMNL
{}))
>
);
// batch stride can be dynamic or static
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
static_assert
(
take
<
0
,
2
>
(
StrideMNL
{})
==
Stride
<
_0
,
_1
>
{});
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
struct
SharedStorage
{
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
>
smem
;
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
};
};
// This struct has been modified to have a bool indicating that ptr_row is a
// This struct has been modified to have a bool indicating that ptr_row is a
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return
args
;
return
args
;
}
}
// Capability query for this visitor node.
// Neither the problem shape nor the arguments are inspected: every
// configuration is reported as implementable.
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape,
                          Arguments const& args) {
  (void)problem_shape;
  (void)args;
  return true;
}
template
<
class
ProblemShape
>
template
<
class
ProblemShape
>
static
size_t
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
,
:
params
(
params
)
smem
_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem
_row
.
data
()))
{
}
,
smem
(
const_cast
<
Element
*>
(
shared_storage
.
smem
.
data
()))
{
}
Params
params
;
Params
params
;
Element
*
smem
_row
;
Element
*
smem
=
nullptr
;
CUTLASS_DEVICE
bool
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
is_producer_load_needed
()
const
{
return
tru
e
;
return
fals
e
;
}
}
CUTLASS_DEVICE
bool
CUTLASS_DEVICE
bool
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
}
// Producer-side callbacks that bulk-copy the broadcast row from gRow into one
// pipeline stage of sRow.
//
// Template parameters:
//   EpiTiles - divisor applied to load_iteration when selecting the stage.
//   GTensor  - type of the source tensor (gRow).
//   STensor  - type of the destination tensor (sRow).
template <int EpiTiles, class GTensor, class STensor>
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
  // Takes ownership of both tensors by forwarding them into the members;
  // `params` is only referenced, so it must outlive this object.
  CUTLASS_DEVICE
  ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
    : gRow(cute::forward<GTensor>(gRow)), sRow(cute::forward<STensor>(sRow)), params(params) {}

  GTensor gRow;                                                                                 // (CTA_M,CTA_N)
  STensor sRow;                                                                                 // (CTA_M,CTA_N,PIPE)
  Params const& params;

  // Issues the row copy for the current load iteration.
  //   full_mbarrier_ptr - barrier whose expected transaction count is raised
  //                       and which the bulk copy is bound to.
  //   load_iteration    - running load counter; selects the smem stage.
  //   issue_tma_load    - when false, nothing is issued.
  CUTLASS_DEVICE void
  begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
    // No row pointer supplied: nothing to load.
    if (params.ptr_row == nullptr) {
      return;
    }

    if (issue_tma_load) {
      // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
      constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
      cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
      // Issue the TMA bulk copy
      auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
      // Filter so we don't issue redundant copies over stride-0 modes
      // The destination stage advances once every EpiTiles iterations and
      // wraps around after Stages stages.
      int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
      copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
    }
  }
};
template
<
class
...
Args
>
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
}
}
template
<
int
EpiTiles
,
class
R
Tensor
,
class
STensor
>
template
<
class
GS_GTensor
,
class
GS_STensor
,
class
GS_CTensor
,
class
Tiled_G2S
,
class
SR_S
Tensor
,
class
S
R_R
Tensor
,
class
CTensor
,
class
ThrResidue
,
class
ThrNum
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
ConsumerStoreCallbacks
(
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
GS_GTensor
tGS_gRow_
,
GS_STensor
tGS_sRow_
,
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
GS_CTensor
tGS_cRow_
,
Tiled_G2S
tiled_g2s_
,
params
(
params
)
{}
SR_STensor
tSR_sRow_
,
SR_RTensor
tSR_rRow_
,
CTensor
tCcRow_
,
ThrResidue
residue_tCcRow_
,
ThrNum
thr_num_
,
Params
const
&
params_
)
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
:
tGS_gRow
(
tGS_gRow_
)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
,
tGS_sRow
(
tGS_sRow_
)
,
tGS_cRow
(
tGS_cRow_
)
,
tiled_G2S
(
tiled_g2s_
)
,
tSR_sRow
(
tSR_sRow_
)
,
tSR_rRow
(
tSR_rRow_
)
,
tCcRow
(
tCcRow_
)
,
residue_tCcRow
(
residue_tCcRow_
)
,
params
(
params_
)
{}
GS_GTensor
tGS_gRow
;
// (CPY,CPY_M,CPY_N)
GS_STensor
tGS_sRow
;
// (CPY,CPY_M,CPY_N)
GS_CTensor
tGS_cRow
;
// (CPY,CPY_M,CPY_N)
Tiled_G2S
tiled_G2S
;
SR_STensor
tSR_sRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor
tSR_rRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor
tCcRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue
residue_tCcRow
;
// (m, n)
ThrNum
thr_num
;
Params
const
&
params
;
Params
const
&
params
;
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
begin
(
)
{
if
(
!
params
.
row_broadcast
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
t
C
rRow
,
*
(
params
.
ptr_row
));
fill
(
t
SR_
rRow
,
*
(
params
.
ptr_row
));
return
;
return
;
}
}
auto
synchronize
=
[
&
]
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
thr_num
,
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
);
};
Tensor
tGS_gRow_flt
=
filter_zeros
(
tGS_gRow
);
Tensor
tGS_sRow_flt
=
filter_zeros
(
tGS_sRow
);
Tensor
tGS_cRow_flt
=
make_tensor
(
tGS_cRow
.
data
(),
make_layout
(
tGS_gRow_flt
.
shape
(),
tGS_cRow
.
stride
()));
for
(
int
i
=
0
;
i
<
size
(
tGS_gRow_flt
);
++
i
)
{
if
(
get
<
1
>
(
tGS_cRow_flt
(
i
))
>=
size
<
1
>
(
CtaTileShapeMNK
{}))
{
continue
;
// OOB of SMEM,
}
if
(
elem_less
(
tGS_cRow_flt
(
i
),
make_coord
(
get
<
0
>
(
residue_tCcRow
),
get
<
1
>
(
residue_tCcRow
))))
{
tGS_sRow_flt
(
i
)
=
tGS_gRow_flt
(
i
);
}
else
{
tGS_sRow_flt
(
i
)
=
Element
(
0
);
// Set to zero when OOB so LDS can be issued without any predicates.
}
}
synchronize
();
}
CUTLASS_DEVICE
void
begin_loop
(
int
epi_m
,
int
epi_n
)
{
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
if
(
!
params
.
row_broadcast
)
return
;
// Do not issue LDS when row is scalar
// (only works if 0-strides are in same location, which is by construction)
Tensor
tSR_sRow_flt
=
filter_zeros
(
tSR_sRow
(
_
,
_
,
_
,
epi_m
,
epi_n
));
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
Tensor
tSR_rRow_flt
=
filter_zeros
(
tSR_rRow
)
;
copy
_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
)
);
copy
(
tSR_sRow_flt
,
tSR_rRow_flt
);
}
}
}
}
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
t
C
rRow
(
epi_v
*
FragmentSize
+
i
);
frg_row
[
i
]
=
t
SR_
rRow
(
epi_v
*
FragmentSize
+
i
);
}
}
return
frg_row
;
return
frg_row
;
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
>
>
CUTLASS_DEVICE
auto
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
using
ThreadCount
=
decltype
(
size
(
args
.
tiled_copy
));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
Tensor
gRow
=
local_tile
(
mRow
(
_
,
_
,
l
),
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
));
// (CTA_M, CTA_N)
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem
),
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})),
make_shape
(
_0
{},
_1
{}));
// (CTA_M, CTA_N)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
//// G2S: Gmem to Smem
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
auto
tiled_g2s
=
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
Layout
<
Shape
<
_1
,
ThreadCount
>
,
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
Stride
<
_0
,
_1
>>
{},
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
Layout
<
_1
>
{});
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
auto
thr_g2s
=
tiled_g2s
.
get_slice
(
args
.
thread_idx
);
Tensor
tGS_gRow
=
thr_g2s
.
partition_S
(
gRow
);
Tensor
tGS_sRow
=
thr_g2s
.
partition_D
(
sRow
);
//// G2S: Coord
auto
cRow
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tGS_cRow
=
thr_g2s
.
partition_S
(
cRow
);
//// S2R: Smem to Reg
Tensor
tSR_sRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tSR_rRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tSR_sRow
));
// (CPY,CPY_M,CPY_N)
return
ConsumerStoreCallbacks
<
decltype
(
tGS_gRow
),
decltype
(
tGS_sRow
),
decltype
(
tGS_cRow
),
decltype
(
tiled_g2s
),
decltype
(
tSR_sRow
),
decltype
(
tSR_rRow
),
decltype
(
args
.
tCcD
),
decltype
(
args
.
residue_cD
),
ThreadCount
>
(
tGS_gRow
,
tGS_sRow
,
tGS_cRow
,
tiled_g2s
,
tSR_sRow
,
tSR_rRow
,
args
.
tCcD
,
args
.
residue_cD
,
ThreadCount
{},
params
);
}
}
};
};
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return
args
;
return
args
;
}
}
// Capability query for this visitor node.
// Neither the problem shape nor the arguments are inspected: every
// configuration is reported as implementable.
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape,
                          Arguments const& args) {
  (void)problem_shape;
  (void)args;
  return true;
}
template
<
class
ProblemShape
>
template
<
class
ProblemShape
>
static
size_t
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return
EmptyProducerLoadCallbacks
{};
return
EmptyProducerLoadCallbacks
{};
}
}
template
<
class
GTensor
,
class
RTensor
>
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
ConsumerStoreCallbacks
(
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
GTensor
&&
tCgCol
,
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
RTensor
&&
tCrCol
,
params
(
params
)
{}
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
Params
const
&
params
;
int
m
;
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
begin
()
{
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
return
;
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
// (only works if 0-strides are in same location, which is by construction)
copy_
aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
copy_
if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
// Generate an identity tensor matching the shape of the global tensor and
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
params
);
}
}
};
};
...
...
csrc/quantization/cutlass_w8a8/common.hpp
View file @
e7c1b7f3
#pragma once
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
#include <climits>
/**
/**
* Helper function for checking CUTLASS errors
* Helper function for checking CUTLASS errors
...
@@ -10,3 +11,17 @@
...
@@ -10,3 +11,17 @@
TORCH_CHECK(status == cutlass::Status::kSuccess, \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
cutlassGetStatusString(status)) \
}
}
// Rounds `num` up to the nearest power of two.
//
// Returns `num` unchanged for 0 and 1, and for any value that is already a
// power of two. Returns 0 when the result would not fit in 32 bits
// (num > 2^31), instead of shifting by the full type width, which is
// undefined behaviour.
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kBits = CHAR_BIT * sizeof(num);
  // num - 1 is non-zero here, so __builtin_clz is well defined.
  uint32_t const shift = kBits - static_cast<uint32_t>(__builtin_clz(num - 1));
  if (shift >= kBits) return 0;  // next power of two is 2^32: not representable

  // Shift an unsigned one so that a result of 2^31 does not overflow `int`.
  return uint32_t{1} << shift;
}
// Looks up cudaDevAttrMaxSharedMemoryPerBlockOptin for `device`.
// The status returned by cudaDeviceGetAttribute is not inspected; if the
// call leaves the output untouched, the initial value 0 is returned.
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int smem_opt_in_limit = 0;
  cudaDeviceGetAttribute(&smem_opt_in_limit,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return smem_opt_in_limit;
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
e7c1b7f3
#include <stddef.h>
#include <stddef.h>
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
// Arch guard for 750 <= __CUDA_ARCH__ < 800.
// invoke() forwards to Kernel::invoke only when compiled for an architecture
// in that range; for any other target the body is empty.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
// Arch guard for 800 <= __CUDA_ARCH__ < 890.
// invoke() forwards to Kernel::invoke only when compiled for an architecture
// in that range; for any other target the body is empty.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
// Arch guard for 890 <= __CUDA_ARCH__ < 900.
// invoke() forwards to Kernel::invoke only when compiled for an architecture
// in that range; for any other target the body is empty.
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
{
private:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
cutlass
::
multiplies
,
float
,
float
,
typename
...
EpilogueArgs
>
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
void
cutlass_scaled_mm_sm75_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
using
EVTCompute0
=
EpilogueArgs
&&
...
epilogue_args
)
{
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
evt_compute_args
;
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
};
// Bundles the type configuration of one CUTLASS 2.x GEMM with a visitor
// epilogue.
//
// Template parameters:
//   Arch             - architecture tag passed to DefaultGemmWithVisitor.
//   ArchGuard        - wrapper template applied to the generated GemmKernel.
//   ElementAB_       - element type of both A and B.
//   ElementD_        - element type of the output D.
//   Epilogue_        - epilogue template, instantiated with
//                      <ElementD, OutputTileThreadMap>; must expose EVTCompute.
//   TileShape / WarpShape / InstructionShape - tiling shapes.
//   MainLoopStages   - stage count passed to DefaultGemmWithVisitor.
//
// `Op` is the resulting device-level GEMM type.
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  // int8 inputs accumulate in int32; every other input type accumulates in float.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;

  // int8 inputs use the saturating multiply-add; other types the plain one.
  using Operator =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>,
                         cutlass::arch::OpMultiplyAddSaturate,
                         cutlass::arch::OpMultiplyAdd>;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  // Store node for D, placed at the root of the visitor tree.
  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD,
      cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
/*
 * Builds the argument pack for Gemm::Op from the given tensors and launches
 * the kernel on the current CUDA stream of `a`'s device.
 *
 * out             - output tensor; its data pointer and stride(0) are used.
 * a, b            - input operands; sizes (a.size(0), b.size(1), a.size(1))
 *                   form the problem size, a.stride(0) / b.stride(1) the
 *                   leading dimensions.
 * epilogue_params - forwarded unchanged to Gemm::Epilogue::prepare_args().
 *
 * Fails through CUTLASS_CHECK when the operation reports that it cannot
 * implement the arguments or when the launch returns a non-success status.
 */
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);

  // Take the scratch buffer from PyTorch (uint8 tensor on a's device) rather
  // than from a raw CUTLASS device allocation, so the workspace goes through
  // the same allocator as every other tensor in this call.
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  // Keep the status in a local: the launch must be evaluated exactly once.
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
}
}
// namespace
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
"currently bias dtype must match output dtype "
,
out
.
dtype
());
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
torch
::
Tensor
const
&
b
,
typename
...
EpilogueArgs
>
torch
::
Tensor
const
&
a_scales
,
void
cutlass_scaled_mm_sm80_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
half_t
,
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
void
cutlass_scaled_mm_sm8
9
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm8
0
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
,
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
ScaledEpilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
else
{
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
WarpShape
,
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
}
}
/*
 * SM89 entry point for the scaled matrix multiply.
 *
 * Both scale tensors must be float32. When `bias` holds a tensor its dtype
 * must equal the dtype of `out`, and the ScaledEpilogueBias epilogue is used;
 * otherwise the plain ScaledEpilogue is used.
 */
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias supplied: scale-only epilogue.
  if (!bias) {
    cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(out, a, b, a_scales,
                                                          b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
vllm
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
// Arch guard: invoke() forwards to Kernel::invoke only in device compilation
// passes where __CUDA_ARCH__ lies in [750, 800); in every other pass the body
// is empty.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Params>
  CUTLASS_DEVICE static void invoke(Params&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
    Kernel::invoke(std::forward<Params>(params)...);
#endif
  }
};
// Arch guard: invoke() forwards to Kernel::invoke only in device compilation
// passes where __CUDA_ARCH__ lies in [800, 890); in every other pass the body
// is empty.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Params>
  CUTLASS_DEVICE static void invoke(Params&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
    Kernel::invoke(std::forward<Params>(params)...);
#endif
  }
};
// Arch guard: invoke() forwards to Kernel::invoke only in device compilation
// passes where __CUDA_ARCH__ lies in [890, 900); in every other pass the body
// is empty.
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Params>
  CUTLASS_DEVICE static void invoke(Params&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
    Kernel::invoke(std::forward<Params>(params)...);
#endif
  }
};
/*
 * Common base of ScaledEpilogue and ScaledEpilogueBias. It only declares the
 * visitor types both of them share: the accumulator fetch node and the
 * ScaleA / ScaleB broadcast descriptors.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  // Descriptor for the B-operand scale: float values, row-or-scalar broadcast.
  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;

  // Descriptor for the A-operand scale: float values, column-or-scalar
  // broadcast.
  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

  // Node that fetches the GEMM accumulator.
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
};
/*
 This epilogue implements a quantized GEMM similar to torch._scaled_mm.

 A and B are either both int8 or both fp8_e4m3. A can be quantized per-tensor
 or per-row, B per-tensor or per-column, in any combination. Quantization must
 be symmetric (zero point == 0).

 The result is D = (a_scales * A) (b_scales * B), with the scales applied
 elementwise using numpy-style broadcasting.

 ScaleA and ScaleB are the epilogue descriptors that apply the scales of the
 A and B operands respectively; each scale is either a single value or one
 value per row (A) / per column (B).
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  // Inner node: multiplies ScaleB with the accumulator, float element types.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Outer node: multiplies ScaleA with the inner result and yields ElementD.
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scale tensors into the argument structure of EVTCompute.
  // For each scale the float data pointer is passed together with a flag that
  // is true when the tensor holds more than one element.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    using ScaleAArgs = typename ScaleA::Arguments;
    using ScaleBArgs = typename ScaleB::Arguments;

    ScaleBArgs scale_b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1,
                            {}};
    ScaleAArgs scale_a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1,
                            {}};

    typename EVTCompute0::Arguments inner_args{scale_b_args};
    return ArgumentType{scale_a_args, inner_args};
  }
};
/*
 * Variant of the scaled epilogue that also adds a bias: the outer node is a
 * multiply_add over (ScaleA, ScaleB * accumulator, Bias). The bias is read as
 * ElementD values through a row-broadcast descriptor.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  // Inner node: multiplies ScaleB with the accumulator, float element types.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Outer node: fused multiply-add yielding ElementD.
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  // Bias descriptor: ElementD values, row broadcast.
  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0,
                                              Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scale and bias tensors into the argument structure of
  // EVTCompute. Scales are passed as a float pointer plus a flag that is true
  // when the tensor holds more than one element; the bias data pointer is
  // reinterpreted as ElementD* without any dtype check in this function.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    using ScaleAArgs = typename ScaleA::Arguments;
    using ScaleBArgs = typename ScaleB::Arguments;
    using BiasArgs = typename Bias::Arguments;

    ScaleBArgs scale_b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1,
                            {}};
    ScaleAArgs scale_a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1,
                            {}};
    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};

    typename EVTCompute0::Arguments inner_args{scale_b_args};
    return ArgumentType{scale_a_args, inner_args, bias_args};
  }
};
/*
 * Collects every type needed to instantiate one CUTLASS 2.x GEMM with a
 * visitor-based epilogue.
 *
 * Template parameters:
 *   Arch             - CUTLASS architecture tag handed to the kernel builder.
 *   ArchGuard        - wrapper template applied to the generated kernel type.
 *   ElementAB_       - element type shared by both input operands.
 *   ElementD_        - element type of the stored output.
 *   Epilogue_        - epilogue template, instantiated with
 *                      <ElementD, OutputTileThreadMap>; it must expose
 *                      EVTCompute and a static prepare_args().
 *   TileShape / WarpShape / InstructionShape - GEMM shape types forwarded to
 *                      the kernel builder.
 *   MainLoopStages   - stage count forwarded to the kernel builder.
 *   FP8MathOperator  - math operator used whenever ElementAB_ is not int8_t
 *                      (defaults to OpMultiplyAdd).
 */
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  // int8 inputs accumulate into int32_t; any other input type uses float.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;

  // int8 inputs use the saturating multiply-add; other inputs use the
  // caller-selected FP8MathOperator.
  using Operator =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>,
                         cutlass::arch::OpMultiplyAddSaturate,
                         FP8MathOperator>;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  // User-supplied epilogue and the compute tree it exposes.
  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  // Store node writing ElementD values through a runtime leading stride.
  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  // Complete epilogue tree: the store node fed by the compute tree.
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  using RowMajor = cutlass::layout::RowMajor;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  // clang-format off
  using KernelType =
      ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
        ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
        float, RowMajor, 4,
        ElementAcc, float, cutlass::arch::OpClassTensorOp,
        Arch,
        TileShape, WarpShape, InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        MainLoopStages, Operator,
        1 /* epilogue stages */
        >::GemmKernel>;
  // clang-format on

  // Device-level operation launched by the caller.
  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
/*
 * Builds the argument pack for Gemm::Op from the given tensors and launches
 * the kernel on the current CUDA stream of `a`'s device.
 *
 * out             - output tensor; its data pointer and stride(0) are used.
 * a, b            - input operands; sizes (a.size(0), b.size(1), a.size(1))
 *                   form the problem size, a.stride(0) / b.stride(1) the
 *                   leading dimensions.
 * epilogue_params - forwarded unchanged to Gemm::Epilogue::prepare_args().
 *
 * Fails through CUTLASS_CHECK when the operation reports that it cannot
 * implement the arguments or when the launch returns a non-success status.
 */
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using Epilogue = typename Gemm::Epilogue;
  using StrideC = Stride<int64_t, Int<1>, Int<0>>;

  // Problem extents, narrowed to the 32-bit coordinates CUTLASS takes here.
  int32_t rows = static_cast<int32_t>(a.size(0));
  int32_t cols = static_cast<int32_t>(b.size(1));
  int32_t depth = static_cast<int32_t>(a.size(1));
  cutlass::gemm::GemmCoord problem_shape{rows, cols, depth};

  // Leading dimensions of the operands and of the output.
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  StrideC out_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_data = static_cast<ElementAB const*>(a.data_ptr());
  auto b_data = static_cast<ElementAB const*>(b.data_ptr());
  auto out_data = static_cast<ElementD*>(out.data_ptr());

  // Epilogue arguments: the store node plus whatever the epilogue prepares.
  typename Gemm::D::Arguments store_args{out_data, out_stride};
  auto compute_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
  typename Gemm::EVTD::Arguments evt_args{
      compute_args,
      store_args,
  };

  typename Gemm::Op::Arguments op_args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_shape,                                          // problem size
      1,                                                      // batch count
      evt_args,
      a_data,
      b_data,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op op;

  // Scratch space is a uint8 tensor allocated on a's device.
  size_t scratch_bytes = op.get_workspace_size(op_args);
  auto const scratch_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto scratch = torch::empty(scratch_bytes, scratch_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(op.can_implement(op_args));
  // Keep the status in a local: the launch must be evaluated exactly once.
  cutlass::Status launch_status = op(op_args, scratch.data_ptr(), stream);
  CUTLASS_CHECK(launch_status);
}
/*
 * Runs Gemm when its shared-memory footprint fits the opt-in per-block limit,
 * otherwise runs FallbackGemm.
 *
 * out, a, b - forwarded to cutlass_gemm_caller.
 * args      - epilogue arguments, forwarded to cutlass_gemm_caller.
 *
 * Raises through TORCH_CHECK when FallbackGemm does not fit either; the
 * message reports the required and the available number of bytes.
 */
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  // NOTE(review): the limit is queried once, for device index 0, and cached
  // for every later call regardless of the device `a` lives on - confirm this
  // is acceptable for mixed-GPU setups.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  // Both footprints are compile-time properties of the kernel types.
  constexpr size_t gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  constexpr size_t fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  // Same unsigned comparison the implicit conversion performed before, now
  // spelled out.
  size_t const shared_mem_limit =
      static_cast<size_t>(max_shared_mem_per_block_opt_in);

  if (gemm_shared_mem_size <= shared_mem_limit) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  }

  TORCH_CHECK(fallback_gemm_shared_mem_size <= shared_mem_limit,
              "fallback GEMM requires ", fallback_gemm_shared_mem_size,
              " bytes of shared memory per block, but the opt-in limit "
              "reported for device 0 is ",
              max_shared_mem_per_block_opt_in, " bytes");
  return cutlass_gemm_caller<FallbackGemm>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_default
{
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M256
{
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M64
{
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M32
{
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
/*
 * Chooses an SM75 kernel configuration from the row count of `a` and runs it
 * through fallback_cutlass_gemm_caller.
 *
 * The row count is rounded up to the next power of two (never below 32) and
 * matched against the buckets 32 / 64 / 128 / 256; anything larger uses the
 * default configuration. Only int8 inputs are accepted.
 */
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same_v<InType, int8_t>);
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using GemmDefault =
      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using GemmM256 =
      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
  using GemmM128 = GemmDefault;
  using GemmM64 =
      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using GemmM32 =
      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm75_config_default has the least shared-memory requirements.
  using FallbackGemm = GemmDefault;

  uint32_t const num_rows = a.size(0);
  // Bucket key: next power of two of the row count, clamped from below to 32.
  uint32_t const bucket =
      std::max(static_cast<uint32_t>(32), next_pow_2(num_rows));

  if (bucket <= 32) {
    // M in [1, 32]
    fallback_cutlass_gemm_caller<GemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
    return;
  }
  if (bucket <= 64) {
    // M in (32, 64]
    fallback_cutlass_gemm_caller<GemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
    return;
  }
  if (bucket <= 128) {
    // M in (64, 128]
    fallback_cutlass_gemm_caller<GemmM128, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
    return;
  }
  if (bucket <= 256) {
    // M in (128, 256]
    fallback_cutlass_gemm_caller<GemmM256, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
    return;
  }

  // M in (256, inf)
  fallback_cutlass_gemm_caller<GemmDefault, FallbackGemm>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
/**
 * Picks an SM80 int8 kernel configuration from the GEMM shape and launches it.
 *
 * The bucket is chosen from a.size(0) rounded up to a power of two (minimum
 * 16); for the (64, 128] bucket out.size(1) is consulted as well. Every launch
 * goes through fallback_cutlass_gemm_caller with FallbackGemm as the
 * alternative kernel.
 *
 * @param out  Output tensor.
 * @param a    Left operand; must have dtype torch::kInt8.
 * @param b    Right operand; must have dtype torch::kInt8.
 * @param args Extra arguments forwarded to the epilogue.
 */
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same_v<InType, int8_t>);
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm80_config_M16 has the least shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as a better alternative
  // performance wise.
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  const uint32_t m = a.size(0);
  // Next power of 2 of M, never smaller than the smallest bucket (16).
  const uint32_t mp2 = std::max(uint32_t{16}, next_pow_2(m));

  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (mp2 <= 128) {
    // M in (64, 128]: the choice additionally depends on N.
    const uint32_t n = out.size(1);
    if (n < 8192) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }

  // M in (128, inf)
  return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_fp8_fallback_gemm
{
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
;
};
struct
sm89_fp8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M128
{
// M in (64, 128]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
/**
 * SM89 FP8 kernel selection for the M64 bucket. TileShape, WarpShape, math
 * operator and number of main-loop stages are all chosen from out.size(1)
 * rounded up to a power of two.
 */
struct sm89_fp8_config_M64 {
  // M in (32, 64]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Launches the selected kernel through fallback_cutlass_gemm_caller, with
  // sm89_fp8_fallback_gemm as the alternative. `args` go to the epilogue.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    // The threshold used to read 8196, a typo for 8192. As long as np2 is a
    // power of two, both values select the same branch.
    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
struct
sm89_fp8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
static
const
int32_t
MainLoopStages
=
5
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
24576
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
/**
 * Entry point for FP8 (e4m3) GEMMs on SM89: selects a per-M-bucket
 * configuration and delegates to its dispatch().
 *
 * The bucket is chosen from a.size(0) rounded up to a power of two, with a
 * minimum of 16 (the smallest bucket).
 *
 * @param out  Output tensor.
 * @param a    Left operand; must have dtype torch::kFloat8_e4m3fn.
 * @param b    Right operand; must have dtype torch::kFloat8_e4m3fn.
 * @param args Extra arguments forwarded to the epilogue.
 */
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  uint32_t const m = a.size(0);
  // Clamp to 16, the smallest bucket. The previous clamp of 32 made the
  // `mp2 <= 16` branch below unreachable, so the M16 configuration was never
  // selected.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
0 → 100644
View file @
e7c1b7f3
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
/**
 * Kernel configuration that the SM89 int8 configs pass as the alternative
 * (FallbackGemm) to fallback_cutlass_gemm_caller.
 */
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
  // Shared mem requirement : 61440
  static_assert(std::is_same<InType, int8_t>());

  using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  static int32_t const MainLoopStages = 5;

  // Use the named stage count rather than repeating the literal, so the
  // declared constant and the instantiated kernel cannot drift apart.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, MainLoopStages>;
};
struct
sm89_int8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M128
{
// M in (64, 128]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
/**
 * Entry point for int8 GEMMs on SM89: selects a per-M-bucket configuration
 * and delegates to its dispatch().
 *
 * The bucket is chosen from a.size(0) rounded up to a power of two, with a
 * minimum of 16 (the smallest bucket).
 *
 * @param out  Output tensor.
 * @param a    Left operand; must have dtype torch::kInt8.
 * @param b    Right operand; must have dtype torch::kInt8.
 * @param args Extra arguments forwarded to the epilogue.
 */
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  uint32_t const m = a.size(0);
  // Clamp to 16, the smallest bucket. The previous clamp of 32 made the
  // `mp2 <= 16` branch below unreachable, so the M16 configuration was never
  // selected.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
e7c1b7f3
...
@@ -18,8 +18,6 @@
...
@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
...
@@ -44,11 +42,6 @@ using namespace cute;
...
@@ -44,11 +42,6 @@ using namespace cute;
namespace
{
namespace
{
// Returns the smallest power of two that is >= num.
//
// Special cases:
//   * num <= 1 is returned unchanged (0 -> 0, 1 -> 1).
//   * num > 2^31 has no representable answer in 32 bits; the result saturates
//     at 2^31 instead of shifting by the full bit width (undefined behaviour).
uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kLargestPow2 = uint32_t{1} << 31;
  if (num > kLargestPow2) return kLargestPow2;

  // num - 1 >= 1 here, so __builtin_clz is never called with 0. The shifted
  // value is unsigned so that reaching bit 31 does not overflow a signed int.
  return uint32_t{1} << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// reduce the size of the compiled binary.
...
@@ -64,6 +57,24 @@ struct enable_sm90_or_later : Kernel {
...
@@ -64,6 +57,24 @@ struct enable_sm90_or_later : Kernel {
}
}
};
};
/*
 * Common base for the ScaledEpilogue and ScaledEpilogueBias classes: it
 * defines the accumulator-fetch node and the ScaleA / ScaleB broadcast
 * descriptors once, so that both epilogues build on identical types.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  // Node that reads the GEMM accumulator.
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // float scale for A, loaded through a column-or-scalar broadcast with
  // stride (1, 0, 0).
  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<_1, _0, _0>>;

  // float scale for B, loaded through a row-or-scalar broadcast with
  // stride (0, 1, 0).
  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<_0, _1, _0>>;
};
/*
/*
This epilogue function defines a quantized GEMM operation similar to
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
torch.scaled_mm_.
...
@@ -81,21 +92,13 @@ struct enable_sm90_or_later : Kernel {
...
@@ -81,21 +92,13 @@ struct enable_sm90_or_later : Kernel {
per row or column.
per row or column.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogue
{
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
private:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
using
ScaleA
=
typename
SUPER
::
ScaleA
;
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
using
ScaleB
=
typename
SUPER
::
ScaleB
;
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleBDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
multiplies
,
float
,
float
,
...
@@ -125,6 +128,50 @@ struct ScaledEpilogue {
...
@@ -125,6 +128,50 @@ struct ScaledEpilogue {
}
}
};
};
/*
 * Scaled epilogue with a fused bias term. The visitor tree built here is
 *
 *   multiply_add(ScaleA, multiplies(ScaleB, Accum), Bias)
 *
 * The inner product is produced as float; the outer node produces ElementD.
 * The bias is read as ElementD through a row broadcast.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Base = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename Base::Accum;
  using ScaleA = typename Base::ScaleA;
  using ScaleB = typename Base::ScaleB;

  // Inner node: ScaleB * Accum, computed and emitted as float.
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  // Outer node: multiply_add computed in float, emitted as ElementD.
  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  // Bias of type ElementD, row-broadcast with stride (0, 1, 0).
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scale and bias tensors into the argument structure expected by
  // EVTCompute. A scale tensor with more than one element sets the flag passed
  // as the second argument of its broadcast node; a single-element tensor
  // clears it. `bias` is reinterpreted as ElementD without a dtype check here.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    using ScaleAArgs = typename ScaleA::Arguments;
    using ScaleBArgs = typename ScaleB::Arguments;
    using BiasArgs = typename Bias::Arguments;

    bool const a_is_vector = a_scales.numel() != 1;
    bool const b_is_vector = b_scales.numel() != 1;

    ScaleAArgs scale_a_args{a_scales.data_ptr<float>(), a_is_vector, {}};
    ScaleBArgs scale_b_args{b_scales.data_ptr<float>(), b_is_vector, {}};
    BiasArgs bias_node_args{static_cast<ElementD*>(bias.data_ptr())};

    return ArgumentType{scale_a_args, {scale_b_args}, bias_node_args};
  }
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
@@ -194,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -194,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
...
@@ -225,24 +272,26 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -225,24 +272,26 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
CUTLASS_CHECK
(
status
);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
int32_t
M
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config
{
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
...
@@ -250,14 +299,14 @@ struct sm90_fp8_config {
...
@@ -250,14 +299,14 @@ struct sm90_fp8_config {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
128
>
{
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
...
@@ -265,7 +314,8 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
...
@@ -265,7 +314,8 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
64
>
{
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
...
@@ -278,6 +328,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
...
@@ -278,6 +328,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
};
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
}
// namespace
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
...
@@ -291,11 +413,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
...
@@ -291,11 +413,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
0
>::
Cutlass3xGemm
;
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
64
>::
Cutlass3xGemm
;
typename
sm90_fp8_config
_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config
<
InType
,
OutType
,
Epilogue
,
128
>::
Cutlass3xGemm
;
typename
sm90_fp8_config
_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
uint32_t
const
mp2
=
...
@@ -316,49 +439,111 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
...
@@ -316,49 +439,111 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
template
<
typename
InType
,
typename
OutType
,
torch
::
Tensor
const
&
b
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
torch
::
Tensor
const
&
a_scales
,
typename
...
EpilogueArgs
>
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
torch
::
Tensor
const
&
b
,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_
caller
<
cutlass_3x_gemm
<
return
cutlass_gemm_
sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
int8_t
,
cutlass
::
bfloat16_t
,
ScaledEpilogue
,
TileShape
,
ClusterShape
,
Epilogue
>
(
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_caller
<
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
cutlass_3x_gemm
<
int8_t
,
cutlass
::
half_t
,
ScaledEpilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
else
{
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Scaled
Epilogue
>
(
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Scaled
Epilogue
>
(
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...
);
}
}
}
}
}
}
// SM90 entry point for the scaled matrix multiply.
//
// Both scale tensors must be float32. When `bias` holds a tensor its dtype
// must equal the dtype of `c`, and the call is routed to the
// ScaledEpilogueBias epilogue; otherwise the ScaledEpilogue epilogue is used.
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: scale-only epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
                                                    b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());
  cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(c, a, b, a_scales,
                                                      b_scales, *bias);
}
#endif
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
e7c1b7f3
...
@@ -6,28 +6,49 @@
...
@@ -6,28 +6,49 @@
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#endif
// Reports whether the CUTLASS FP8 kernels can be used on a device with the
// given compute capability (encoded as major * 10 + minor, e.g. 89 or 90),
// based on the CUDA version this file was compiled with:
//   capability >= 90 (Hopper)   needs CUDA 12.0 or newer
//   capability >= 89 (Lovelace) needs CUDA 12.4 or newer
// Lower capabilities, and builds where CUDA_VERSION is not defined, always
// report false.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) return CUDA_VERSION >= 12000;
  if (cuda_device_capability >= 89) return CUDA_VERSION >= 12040;
#endif

  return false;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
int32_t
major_capability
;
int32_t
major_capability
;
int32_t
minor_capability
;
int32_t
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
...
@@ -50,6 +71,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -50,6 +71,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
if
(
version_num
>=
90
)
{
if
(
version_num
>=
90
)
{
...
@@ -57,19 +83,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -57,19 +83,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#endif
#endif
}
else
if
(
version_num
==
89
)
{
}
else
if
(
version_num
==
89
)
{
// Ada Lovelace
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
}
else
if
(
version_num
>=
80
)
{
// Ampere
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
}
else
{
// Turing
// Turing
TORCH_CHECK
(
version_num
>=
75
);
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
}
}
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
e7c1b7f3
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
return
{};
// Squash missing return statement warning
}
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
return
{};
// Squash missing return statement warning
}
}
// The following macro is used to dispatch the conversion function based on
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/fp8/common.cu
View file @
e7c1b7f3
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "../../reduction_utils.cuh"
namespace
vllm
{
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
...
@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
...
@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template
<
typename
scalar_t
>
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
scaled_fp8_conversion
(
__device__
__forceinline__
c10
::
Float8_e4m3fn
scaled_fp8_conversion
(
const
scalar_t
val
,
const
float
inverted_scale
)
{
float
const
val
,
float
const
scale
)
{
float
x
=
static_cast
<
float
>
(
val
)
*
inverted_scale
;
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
}
else
{
x
=
val
/
scale
;
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
}
}
...
@@ -40,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
...
@@ -40,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
__shared__
float
cache
[
1024
];
int
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
// the current thread in cache[threadIdx.x]
...
@@ -87,6 +95,70 @@ typedef struct __align__(4) {
...
@@ -87,6 +95,70 @@ typedef struct __align__(4) {
}
}
float8x4_t
;
float8x4_t
;
// Computes the largest absolute value among the elements of `input` visited
// by one thread. The thread starts at index `tid` and advances by `step`.
//
// The bulk of the data is read four elements at a time through vec4_t; the
// num_elems % 4 trailing elements are then read one by one.
// NOTE(review): the vec4_t reinterpretation presumably requires `input` to be
// suitably aligned for vec4_t<scalar_t> - confirm with the callers.
//
// Returns 0.0f when the thread visits no element.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // View the input as packs of four elements to better utilize memory
  // bandwidth.
  auto const* packed_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  int64_t const num_packs = num_elems >> 2;

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_packs; pack_idx += step) {
    vec4_t<scalar_t> pack = packed_in[pack_idx];
    running_max = max(running_max, fabs(pack.x));
    running_max = max(running_max, fabs(pack.y));
    running_max = max(running_max, fabs(pack.z));
    running_max = max(running_max, fabs(pack.w));
  }

  // Scalar tail: elements left over when num_elems is not a multiple of 4.
  int64_t const tail_begin = num_packs * 4;
  for (int64_t idx = tail_begin + tid; idx < num_elems; idx += step) {
    running_max = max(running_max, fabs(input[idx]));
  }

  return running_max;
}
// Converts the elements of `input` visited by one thread to FP8 (e4m3fn) and
// writes them to the same positions in `out`. The thread starts at index
// `tid` and advances by `step`.
//
// Every element is cast to float and passed, together with `scale`, to
// scaled_fp8_conversion<is_scale_inverted>. Elements are handled four at a
// time through vec4_t / float8x4_t; the num_elems % 4 trailing elements are
// then converted one by one.
// NOTE(review): the packed reinterpretations presumably require `input` and
// `out` to be suitably aligned - confirm with the callers.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // View input and output as packs of four elements to better utilize memory
  // bandwidth.
  auto const* packed_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* packed_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_packs = num_elems >> 2;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_packs; pack_idx += step) {
    vec4_t<scalar_t> src = packed_in[pack_idx];
    float8x4_t dst;

    dst.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(src.x), scale);
    dst.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(src.y), scale);
    dst.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(src.z), scale);
    dst.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(src.w), scale);

    packed_out[pack_idx] = dst;
  }

  // Scalar tail: elements left over when num_elems is not a multiple of 4.
  int64_t const tail_begin = num_packs * 4;
  for (int64_t idx = tail_begin + tid; idx < num_elems; idx += step) {
    out[idx] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[idx]), scale);
  }
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
__global__
void
scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
...
@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
...
@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// Invert the scale so that we can use multiplications to avoid expensive
// division.
// division.
const
float
inverted_scale
=
1.0
f
/
(
*
scale
);
const
float
inverted_scale
=
1.0
f
/
(
*
scale
);
scaled_fp8_conversion_vec
<
scalar_t
,
true
>
(
out
,
input
,
inverted_scale
,
num_elems
,
tid
,
blockDim
.
x
*
gridDim
.
x
);
}
// Vectorized input/output to better utilize memory bandwidth.
template
<
typename
scalar_t
>
const
vec4_t
<
scalar_t
>*
vectorized_in
=
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
reinterpret_cast
<
const
vec4_t
<
scalar_t
>*>
(
input
);
c10
::
Float8_e4m3fn
*
__restrict__
out
,
float
*
__restrict__
scale
,
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
int
num_vec_elems
=
num_elems
>>
2
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
#pragma unroll 4
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
for
(
int
i
=
tid
;
i
<
num_vec_elems
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
c10
::
Float8_e4m3fn
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
(
in_vec
.
x
,
inverted_scale
);
// For vectorization, token_input and token_output pointers need to be
out_vec
.
y
=
scaled_fp8_conversion
(
in_vec
.
y
,
inverted_scale
);
// aligned at 8-byte and 4-byte addresses respectively.
out_vec
.
z
=
scaled_fp8_conversion
(
in_vec
.
z
,
inverted_scale
);
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
out_vec
.
w
=
scaled_fp8_conversion
(
in_vec
.
w
,
inverted_scale
);
vectorized_out
[
i
]
=
out_vec
;
float
absmax_val
=
0.0
f
;
if
(
can_vectorize
)
{
absmax_val
=
thread_max_vec
(
token_input
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
absmax_val
=
max
(
absmax_val
,
fabs
(
x
));
}
}
}
// Handle the remaining elements if num_elems is not divisible by 4
float
const
block_absmax_val_maybe
=
blockReduceMax
(
absmax_val
);
for
(
int
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
__shared__
float
token_scale
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
tid
==
0
)
{
out
[
i
]
=
scaled_fp8_conversion
(
input
[
i
],
inverted_scale
);
if
(
scale_ub
)
{
token_scale
=
min
(
block_absmax_val_maybe
,
*
scale_ub
);
}
else
{
token_scale
=
block_absmax_val_maybe
;
}
// token scale computation
token_scale
=
max
(
token_scale
/
FP8_E4M3_MAX
,
min_scaling_factor
);
scale
[
token_idx
]
=
token_scale
;
}
__syncthreads
();
// Note that we don't use inverted scales so we can match FBGemm impl.
if
(
can_vectorize
)
{
scaled_fp8_conversion_vec
<
scalar_t
,
false
>
(
token_output
,
token_input
,
token_scale
,
hidden_size
,
tid
,
blockDim
.
x
);
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
token_output
[
i
]
=
scaled_fp8_conversion
<
false
>
(
static_cast
<
float
>
(
token_input
[
i
]),
token_scale
);
}
}
}
}
}
}
// namespace vllm
}
// namespace vllm
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
torch
::
Tensor
const
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
int64_t
num_elems
=
input
.
numel
();
...
@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
});
}
}
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
torch
::
Tensor
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
int64_t
num_elems
=
input
.
numel
();
...
@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale
.
data_ptr
<
float
>
(),
num_elems
);
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
});
}
}
// Host entry point for dynamic per-token FP8 quantization.
//
// Launches `vllm::dynamic_per_token_scaled_fp8_quant_kernel` with one thread
// block per token (row of the last dimension) and min(hidden_size, 1024)
// threads per block, on the current CUDA stream of `input`'s device.
//
//   out      : destination FP8 tensor, must be contiguous   // [..., d]
//   input    : source tensor, must be contiguous            // [..., d]
//   scales   : float tensor the kernel writes scales into
//   scale_ub : optional float tensor; its data pointer is handed to the
//              kernel, or nullptr when no value is present
//
// An empty `input` is a no-op: nothing is launched and no tensor is written.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // With zero elements either hidden_size is 0 (the division below would
  // divide by zero) or num_tokens is 0 (a zero-sized grid is an invalid
  // launch configuration). There is nothing to quantize, so return early.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;

  // One block per token; the block never exceeds 1024 threads.
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}
csrc/quantization/fp8/fp8_marlin.cu
0 → 100644
View file @
e7c1b7f3
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
using
namespace
marlin
;
// Compile-time guard: fails the build unless `scalar_t` is exactly `half` or
// `nv_bfloat16`. The expansion already ends with a semicolon.
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
// Returns the decimal text of `x` as produced by std::to_string.
template <typename T>
inline std::string str(T x) {
  std::string text = std::to_string(x);
  return text;
}
namespace
fp8_marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Placeholder used when compiling device code for __CUDA_ARCH__ < 800: it
// declares the `Marlin` kernel with its full template and argument list, but
// the body is intentionally empty.
template <typename scalar_t,          // compute dtype, half or nv_bfloat16
          const int num_bits,         // number of bits used for weights
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // input matrix of shape mxk
    const int4* __restrict__ B,  // quantized weight matrix of shape kxn
    int4* __restrict__ C,        // output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // quantization scales of shape
                                          // (k/groupsize)xn
    int num_groups,  // number of scale groups per output channel
    int prob_m,      // batch dimension m
    int prob_n,      // output dimension n
    int prob_k,      // reduction dimension k
    int* locks       // extra global storage for barrier synchronization
) {}
}
// namespace fp8_marlin
// Fallback definition of `fp8_marlin_gemm` for the preprocessor branch taken
// when __CUDA_ARCH__ is defined and below 800. It performs no computation: it
// always raises a "not implemented" error through
// TORCH_CHECK_NOT_IMPLEMENTED. All arguments are ignored.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k) {
  // The message names this entry point so the failure can be traced to the
  // FP8 Marlin path rather than to another Marlin variant.
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "fp8_marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  // Only here to give the function a well-formed return path; the check
  // above always fails first.
  return torch::empty({1, 1});
}
#else
// Issues one m16n8k16 tensor core `mma.sync` instruction with fp16 or bf16
// inputs (selected by `scalar_t`) and fp32 output/accumulation:
// `frag_c` is used both as the accumulator input and as the destination.
// Any other `scalar_t` is rejected at compile time.
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
                           const typename ScalarType<scalar_t>::FragB& frag_b,
                           typename ScalarType<scalar_t>::FragC& frag_c) {
  // The PTX operands are plain 32-bit registers, so the fragments are viewed
  // as arrays of raw words (A, B) and floats (accumulator).
  const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
  float* acc = reinterpret_cast<float*>(&frag_c);

  constexpr bool is_fp16 = std::is_same<scalar_t, half>::value;
  constexpr bool is_bf16 = std::is_same<scalar_t, nv_bfloat16>::value;

  if constexpr (is_fp16) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]), "f"(acc[0]), "f"(acc[1]),
          "f"(acc[2]), "f"(acc[3]));
  } else if constexpr (is_bf16) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]), "f"(acc[0]), "f"(acc[1]),
          "f"(acc[2]), "f"(acc[3]));
  } else {
    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  }
}
// Loads a full 16x16 matrix fragment of operand A from shared memory into
// `frag_a`, directly in tensor core layout, using a single
// `ldmatrix ... x4` instruction. `smem_ptr` is a generic pointer that is
// converted to a shared-space address before being passed to the instruction.
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
                             const void* smem_ptr) {
  // The instruction fills four 32-bit registers; view the fragment as such.
  uint32_t* regs = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3])
      : "r"(smem_addr));
}
// Fast FP8ToFp16 / FP8ToBf16: efficiently dequantize 8-bit fp8_e4m3 values to
// fp16 or bf16.
// References:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
// Primary template: it contains only the scalar-type static assertion and
// produces no value. The actual conversions are provided by the `half` and
// `nv_bfloat16` specializations.
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
// FP16 specialization: expands the four FP8 (E4M3) bytes packed in `q` into
// two half2 values.
//
// Each FP8 byte is placed in the upper byte of a 16-bit lane, its
// exponent/mantissa bits are shifted right by the difference in exponent
// width, and the result is multiplied by a power of two that compensates for
// the different exponent biases.
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
  // Field widths of the FP8 (E4M3) and FP16 formats.
  constexpr int kFp8ExpBits = 4;
  constexpr int kFp8ManBits = 3;
  constexpr int kFp16ExpBits = 5;
  // FP16 has this many more exponent bits than FP8.
  constexpr int kExpShift = kFp16ExpBits - kFp8ExpBits;

  // Build the mask that selects exponent + mantissa of the FP8 byte sitting
  // in the upper byte of each 16-bit lane: shift the top bit right over the
  // 7 exponent/mantissa bits, drop the sign bit, and replicate the result
  // into the lower lane.
  constexpr int kTopBit = 0x80000000;
  constexpr int kTopByte = kTopBit >> (kFp8ExpBits + kFp8ManBits);
  constexpr int kTopByteNoSign = kTopByte & 0x7fffffff;
  constexpr int kExpManMask = kTopByteNoSign | (kTopByteNoSign >> 16);
  // Final mask value: 0x7F007F00

  // `upper_pair` is built from the upper byte of each 16-bit lane of `q`;
  // `lower_pair` from the lower byte, moved up by 8 bits first. The sign bit
  // stays in place, exponent and mantissa move down by kExpShift.
  const int q_shifted = q << 8;
  int upper_pair = (q & 0x80008000) | ((q & kExpManMask) >> kExpShift);
  int lower_pair =
      (q_shifted & 0x80008000) | ((q_shifted & kExpManMask) >> kExpShift);

  // Difference between the two exponent biases, applied as a multiplication
  // by 2^kBiasDelta.
  constexpr int kBiasDelta =
      (1 << (kFp16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));
  const half2 bias_factor =
      __float2half2_rn(static_cast<float>(1 << kBiasDelta));

  typename ScalarType<half>::FragB result;
  // Note: reverse indexing is intentional because weights are permuted
  result[1] =
      __hmul2(*reinterpret_cast<const half2*>(&upper_pair), bias_factor);
  result[0] =
      __hmul2(*reinterpret_cast<const half2*>(&lower_pair), bias_factor);
  return result;
}
// BF16 specialization: expands the four FP8 (E4M3) bytes packed in `q` into
// two nv_bfloat162 values.
//
// Each FP8 byte is placed in the upper byte of a 16-bit lane, its
// exponent/mantissa bits are shifted right by the difference in exponent
// width, and the result is multiplied by a power of two that compensates for
// the different exponent biases.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
  // Field widths of the FP8 (E4M3) and BF16 formats.
  constexpr int kFp8ExpBits = 4;
  constexpr int kFp8ManBits = 3;
  constexpr int kBf16ExpBits = 8;
  // BF16 has this many more exponent bits than FP8.
  constexpr int kExpShift = kBf16ExpBits - kFp8ExpBits;

  // Build the mask that selects exponent + mantissa of the FP8 byte sitting
  // in the upper byte of each 16-bit lane: shift the top bit right over the
  // 7 exponent/mantissa bits, drop the sign bit, and replicate the result
  // into the lower lane.
  constexpr int kTopBit = 0x80000000;
  constexpr int kTopByte = kTopBit >> (kFp8ExpBits + kFp8ManBits);
  constexpr int kTopByteNoSign = kTopByte & 0x7fffffff;
  constexpr int kExpManMask = kTopByteNoSign | (kTopByteNoSign >> 16);
  // Final mask value: 0x7F007F00

  // `upper_pair` is built from the upper byte of each 16-bit lane of `q`;
  // `lower_pair` from the lower byte, moved up by 8 bits first. The sign bit
  // stays in place, exponent and mantissa move down by kExpShift.
  const int q_shifted = q << 8;
  int upper_pair = (q & 0x80008000) | ((q & kExpManMask) >> kExpShift);
  int lower_pair =
      (q_shifted & 0x80008000) | ((q_shifted & kExpManMask) >> kExpShift);

  // Difference between the two exponent biases.
  constexpr int kBiasDelta =
      (1 << (kBf16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));
  // Add 127 (float exponent bias) to the bias difference and shift it to the
  // float exponent position, giving the bit pattern of the float
  // 2^kBiasDelta.
  constexpr uint32_t kBiasBits = (kBiasDelta + 127) << 23;
  const nv_bfloat162 bias_factor =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&kBiasBits));

  typename ScalarType<nv_bfloat16>::FragB result;
  // Note: reverse indexing is intentional because weights are permuted
  result[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&upper_pair),
                      bias_factor);
  result[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&lower_pair),
                      bias_factor);
  return result;
}
// Multiplies both halves of the dequantized fragment `frag_b` by the
// quantization scale stored at scalar index `i` of `frag_s`; used only for
// grouped quantization. The scale is broadcast into a two-element vector so
// that each half of the fragment is scaled with one `__hmul2`.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
                             typename ScalarType<scalar_t>::FragS& frag_s,
                             int i) {
  using Dtype = ScalarType<scalar_t>;
  using pair_t = typename Dtype::scalar_t2;

  // View the scale fragment as a flat array of scalars and pick entry `i`.
  scalar_t* scale_values = reinterpret_cast<scalar_t*>(&frag_s);
  pair_t factor = Dtype::num2num2(scale_values[i]);

  frag_b[0] = __hmul2(frag_b[0], factor);
  frag_b[1] = __hmul2(frag_b[1], factor);
}
// Multiplies the two floats `c[0]` and `c[1]` in place by the first and
// second scalar stored in `s`, each converted to float first. The products
// use `__fmul_rn`.
template <typename scalar_t>
__device__ inline void scale_float(float* c,
                                   typename ScalarType<scalar_t>::FragS& s) {
  using Dtype = ScalarType<scalar_t>;

  // View the scale fragment as a flat array of scalars.
  scalar_t* scale_values = reinterpret_cast<scalar_t*>(&s);

  const float first_factor = Dtype::num2float(scale_values[0]);
  c[0] = __fmul_rn(c[0], first_factor);

  const float second_factor = Dtype::num2float(scale_values[1]);
  c[1] = __fmul_rn(c[1], second_factor);
}
// Waits until the barrier word at `lock` reaches `count`, then lets the
// current threadblock proceed. Only thread 0 polls the lock; every thread of
// the block then meets at the `__syncthreads()` below.
__device__ inline void barrier_acquire(int* lock, int count) {
  const bool is_polling_thread = (threadIdx.x == 0);
  if (is_polling_thread) {
    int observed = -1;
    do {
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(observed)
                   : "l"(lock));
    } while (observed != count);
  }
  __syncthreads();
}
// Releases the barrier and increments its visitation count. With `reset` set,
// the barrier word is written back to 0 instead of being incremented.
// All threads of the block synchronize first; only thread 0 touches `lock`.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();

  // Every thread other than thread 0 is done after the block-wide sync.
  if (threadIdx.x != 0) {
    return;
  }

  if (reset) {
    lock[0] = 0;
    return;
  }

  int increment = 1;
  // Make sure that all writes since acquiring this barrier are visible
  // globally, while releasing the barrier.
  asm volatile("fence.acq_rel.gpu;\n");
  asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
               :
               : "l"(lock), "r"(increment));
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_bfloat16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
constexpr
int
pack_factor
=
32
/
num_bits
;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We scale a `half2` tile in row-major layout for column-wise quantization.
int
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location, which we have to reduce over in the end. We do this in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduced final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to give noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
thread_block_reduce
();
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
start_pipes
();
}
}
}
}
// Expands to one `else if` branch of the kernel dispatch chain. The branch is
// taken when the runtime values of num_bits, thread_m_blocks, thread_n_blocks,
// thread_k_blocks, group_blocks and num_threads all equal the macro arguments.
// It then raises the dynamic shared memory limit of the matching `Marlin`
// instantiation to `max_shared_mem` and launches it with `blocks` thread
// blocks of NUM_THREADS threads on `stream`.
// Must follow an `if (...) {}` and relies on the caller's local variables
// named in the expansion (A_ptr, B_ptr, C_ptr, s_ptr, locks, ...).
// NOTE(review): the return value of cudaFuncSetAttribute is not checked.
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
locks); \
}
// Thread-block tiling for one kernel launch configuration.
struct thread_config_t {
  int thread_k;     // `k` size of a thread tile in the weights
  int thread_n;     // `n` size of a thread tile in the weights
  int num_threads;  // threads per thread block
};
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
// Candidate tilings tried by determine_thread_config() when prob_m <= 16.
// Ordered by priority; fields are {thread_k, thread_n, num_threads}.
thread_config_t small_batch_thread_configs[] = {
    thread_config_t{128, 128, 256},
    thread_config_t{64, 128, 128},
    thread_config_t{128, 64, 128},
};
// Candidate tilings tried by determine_thread_config() when prob_m > 16.
// Ordered by priority; fields are {thread_k, thread_n, num_threads}.
thread_config_t large_batch_thread_configs[] = {
    thread_config_t{64, 256, 256},
    thread_config_t{64, 128, 128},
    thread_config_t{128, 64, 128},
};
// Returns the size reserved for scales across all pipeline stages for the
// given thread-block tiling. Only th_config.thread_n influences the result;
// the remaining parameters are accepted but unused here.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits,
                          int group_size) {
  // Max scale groups per thread block: fixed to one for channelwise scales.
  constexpr int groups_per_block = 1;

  // NOTE(review): the factor 2 is presumably the byte width of one 16-bit
  // scale - confirm.
  const int scales_per_stage = groups_per_block * th_config.thread_n * 2;

  return scales_per_stage * pipe_stages;
}
// Checks whether the A/B pipeline buffers implied by `th_config` fit into the
// shared memory left over once `scales_cache_size` has been reserved.
//
// Returns true when the total pipeline size is below 95% of
// (max_shared_mem - scales_cache_size). Raises via TORCH_CHECK when no
// positive M-block count <= max_m_blocks fits m_blocks, or when the scales
// cache takes up half or more of the shared memory.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  const int pack_factor = 32 / num_bits;

  // Size of one B tile per pipeline stage.
  const int tb_k = th_config.thread_k;
  const int tb_n = th_config.thread_n;
  const int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Size of one A tile per pipeline stage: lower the M-block count until it
  // does not exceed the number of 16-row blocks actually present.
  const int m_blocks = div_ceil(prob_m, 16);
  int usable_m_blocks = max_m_blocks;
  while (m_blocks < usable_m_blocks) {
    --usable_m_blocks;
    if (usable_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }
  const int tb_max_m = 16 * usable_m_blocks;
  const int a_size = (tb_max_m * tb_k) * 2;

  const float pipe_size = (a_size + b_size) * pipe_stages;

  // Sanity: the scales cache must use less than half of the shared memory.
  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);

  const float budget = 0.95f * (max_shared_mem - scales_cache_size);
  return pipe_size < budget;
}
// Returns true when `th_config` can be used for the given problem: all fields
// are set, K/N are divisible by the tile sizes, the tile sizes and thread
// count meet their minimums, and the pipeline fits into shared memory.
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, int max_shared_mem) {
  // Sanity: -1 marks an unset field.
  const bool has_unset_field = th_config.thread_k == -1 ||
                               th_config.thread_n == -1 ||
                               th_config.num_threads == -1;
  if (has_unset_field) {
    return false;
  }

  // K/N must be divisible by thread K/N.
  const bool divisible = prob_k % th_config.thread_k == 0 &&
                         prob_n % th_config.thread_n == 0;
  if (!divisible) {
    return false;
  }

  // Thread K/N must reach their minimums.
  const bool tiles_large_enough = th_config.thread_n >= min_thread_n &&
                                  th_config.thread_k >= min_thread_k;
  if (!tiles_large_enough) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps).
  if (th_config.num_threads < 128) {
    return false;
  }

  // Reserve the scales cache, then check that the pipeline fits in the rest.
  const int scales_cache_size = get_scales_cache_size(
      th_config, prob_m, prob_n, prob_k, num_bits, group_size);

  return is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                             num_bits, scales_cache_size, max_shared_mem);
}
// Picks the first valid launch configuration for the given problem.
//
// Tries max_m_blocks = 4, 3, 2, 1 in that order; for each value the candidate
// tilings (small-batch table when prob_m <= 16, large-batch table otherwise)
// are tested in priority order with is_valid_config().
// Returns {0, {-1, -1, -1}} when no candidate is valid.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      int max_shared_mem) {
  const bool small_batch = prob_m <= 16;

  thread_config_t const* candidates = small_batch
                                          ? small_batch_thread_configs
                                          : large_batch_thread_configs;
  const int num_candidates = static_cast<int>(
      small_batch ? sizeof(small_batch_thread_configs) /
                        sizeof(small_batch_thread_configs[0])
                  : sizeof(large_batch_thread_configs) /
                        sizeof(large_batch_thread_configs[0]));

  // Processing fewer M blocks per invocation reduces cache usage, so fall
  // back to smaller counts when nothing fits.
  for (int max_m_blocks = 4; max_m_blocks > 0; --max_m_blocks) {
    for (int idx = 0; idx < num_candidates; ++idx) {
      thread_config_t const& candidate = candidates[idx];
      if (is_valid_config(candidate, max_m_blocks, prob_m, prob_n, prob_k,
                          num_bits, group_size, max_shared_mem)) {
        return exec_config_t{max_m_blocks, candidate};
      }
    }
  }

  return exec_config_t{0, {-1, -1, -1}};
}
// Expands to four __CALL_IF dispatch branches, one per thread_m_blocks value
// 1..4, for the given NUM_BITS / N_BLOCKS / K_BLOCKS / NUM_THREADS. The
// GROUP_BLOCKS argument is always -1.
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
8
,
"num_bits must be 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
-
1
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
int
*
locks
=
(
int
*
)
workspace
;
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
// Define kernel configurations
if
(
false
)
{
}
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
}
}
}
// namespace fp8_marlin
// Entry point for the FP8 Marlin GEMM.
//
// Validates shapes, devices and contiguity of the inputs, allocates the
// {size_m, size_n} output with the dtype/device of `a`, and dispatches to
// fp8_marlin::marlin_mm_f16i4 for float16 or bfloat16 activations.
//
// a          : activations, a.size(0) == size_m and a.size(1) == size_k
// b_q_weight : packed weights; size(0) == size_k / tile_size and size(1)
//              divisible by tile_size
// b_scales   : rank-2 scales of shape {1, size_n} (channelwise only)
// workspace  : at least (size_n / min_thread_n) * max_par elements
// num_bits   : must be 8
//
// Returns the output tensor; raises via TORCH_CHECK on any violated
// precondition or unsupported dtype.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k) {
  // Verify num_bits
  TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
  int pack_factor = 32 / num_bits;

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
              " is not divisible by tile_size = ", marlin::tile_size);
  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not divisible by tile_size = ", marlin::tile_size);
  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect the number of scale groups; group_size stays -1.
  int num_groups = -1;
  int group_size = -1;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  // Channelwise only for FP8. The check is now a properly terminated
  // statement with a message instead of relying on the macro's expansion.
  TORCH_CHECK(b_scales.size(0) == 1, "b_scales dim 0 = ", b_scales.size(0),
              " is not 1 (only channelwise scales are supported)");
  num_groups = b_scales.size(0);

  // Verify workspace size
  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    fp8_marlin::marlin_mm_f16i4<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), num_bits, num_groups, group_size, dev,
        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        marlin::max_par);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
        size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
        dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        marlin::max_par);
  } else {
    TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}
#endif
csrc/quantization/fp8/nvidia/quant_utils.cuh
View file @
e7c1b7f3
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE
,
fp8_type
);
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
return
(
uint8_t
)
res
;
#endif
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// float -> fp8
// float -> fp8
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// The following macro is used to dispatch the conversion function based on
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
0 → 100644
View file @
e7c1b7f3
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {

// Empty placeholder kernel for the branch compiled when __CUDA_ARCH__ < 800.
template <const int num_threads, const int num_bits, const bool has_perm>
__global__ void awq_marlin_repack_kernel(
    const uint32_t* __restrict__ b_q_weight_ptr,
    uint32_t* __restrict__ out_ptr, int size_k, int size_n) {}

}  // namespace marlin
// Stub used in the branch compiled when __CUDA_ARCH__ < 800: always raises a
// not-implemented error. The trailing return is never reached.
//
// NOTE(review): this stub takes an extra `perm` argument that the full
// implementation in the #else branch does not - confirm which signature is
// declared to callers before aligning the two.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  // The message names this function (it previously named
  // marlin_repack_from_gptq).
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "awq_marlin_repack(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}
#else
namespace marlin {

// Repacks a packed quantized weight matrix into the tiled output layout.
//
// Each thread block handles a contiguous range of k-tiles. Within a k-tile
// the n-tiles are streamed through a `repack_stages`-deep shared memory
// pipeline: fetch_to_shared() issues async copies, repack_tile() unpacks the
// values of one staged tile and writes the re-packed words to `out_ptr`.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int k_tiles_per_block = div_ceil(k_tiles, gridDim.x);

  // Blocks past the last k-tile have nothing to do.
  int k_tile_begin = blockIdx.x * k_tiles_per_block;
  if (k_tile_begin >= k_tiles) {
    return;
  }
  int k_tile_end = min(k_tile_begin + k_tiles_per_block, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the
    // previous shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  // Issue the async copy of tile (k_tile_id, n_tile_id) into pipeline slot
  // `pipe`. Out-of-range n-tiles only commit an empty fence.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int n_packed_base = (n_tile_id * tile_n_size) / pack_factor;
    int4* stage_base = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int row = threadIdx.x / stage_n_threads;
      int col = threadIdx.x % stage_n_threads;
      int k_base = k_tile_id * tile_k_size;

      cp_async4(&stage_base[row * stage_n_threads + col],
                reinterpret_cast<int4 const*>(&(
                    b_q_weight_ptr[(k_base + row) * (size_n / pack_factor) +
                                   n_packed_base + (col * 4)])));
    }

    cp_async_fence();
  };

  // Unpack the values of the tile staged in slot `pipe` and store the
  // re-packed words for tile (k_tile_id, n_tile_id).
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp = threadIdx.x / 32;
    int lane = threadIdx.x % 32;

    // Only the first four warps take part in repacking.
    if (warp >= 4) {
      return;
    }

    int tc_col = lane / 4;
    int tc_row = (lane % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* stage_base = sh + stage_size * pipe;
    uint32_t* stage_words = reinterpret_cast<uint32_t*>(stage_base);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t unpacked[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = stage_words[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = stage_words[cur_n_packed + (8 / pack_factor) +
                                     sh_stride * cur_elem];

      unpacked[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      unpacked[4 + i] =
          (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    // Number of 32-bit words in one output tile.
    constexpr int words_per_tile = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * words_per_tile;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t word = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        word |= unpacked[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + lane * 4 + warp] = word;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t word_lo = 0;
      uint32_t word_hi = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        word_lo |= unpacked[pack_idx[i]] << (i * 8);
        word_hi |= unpacked[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + lane * 8 + (warp * 2) + 0] = word_lo;
      out_ptr[out_offset + lane * 8 + (warp * 2) + 1] = word_hi;
    }
  };

  // Pre-fill the pipeline with the first `repack_stages - 1` n-tiles.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = k_tile_begin; k_tile_id < k_tile_end; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        // Refill the slot `repack_stages - 1` tiles ahead, then consume the
        // current one.
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin
// Expands to one `else if` dispatch branch: when the runtime `num_bits` equals
// NUM_BITS, raise the dynamic shared memory limit of the matching
// awq_marlin_repack_kernel instantiation to `max_shared_mem` and launch it
// with `blocks` thread blocks of marlin::repack_threads threads on `stream`.
// Must follow an `if (...) {}` and relies on the caller's locals named in the
// expansion (b_q_weight_ptr, out_ptr, size_k, size_n, ...).
// NOTE(review): the return value of cudaFuncSetAttribute is not checked.
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
// Repacks a packed quantized weight tensor into the Marlin tile layout.
//
// b_q_weight : contiguous kInt CUDA tensor of shape
//              {size_k, size_n / pack_factor}, pack_factor = 32 / num_bits
// size_k     : must be divisible by marlin::tile_k_size
// size_n     : must be divisible by marlin::tile_n_size
// num_bits   : 4 or 8
//
// Returns a tensor of shape
// {size_k / tile_size, size_n * tile_size / pack_factor} with the dtype and
// device of b_q_weight. Raises via TORCH_CHECK on any violated precondition.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  // `blocks` becomes the grid size. Zero-initialize and validate it so a
  // failed attribute query cannot launch with an indeterminate value.
  int blocks = 0;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(blocks > 0, "Failed to query multiprocessor count, got = ",
              blocks);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Dispatch to the kernel instantiation matching num_bits.
  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}
#endif
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
e7c1b7f3
...
@@ -19,8 +19,9 @@
...
@@ -19,8 +19,9 @@
* Adapted from https://github.com/IST-DASLab/marlin
* Adapted from https://github.com/IST-DASLab/marlin
*/
*/
#include "gptq_marlin.cuh"
#include "marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
static_assert(std::is_same<scalar_t, half>::value || \
...
@@ -32,7 +33,7 @@ inline std::string str(T x) {
...
@@ -32,7 +33,7 @@ inline std::string str(T x) {
return
std
::
to_string
(
x
);
return
std
::
to_string
(
x
);
}
}
namespace
gptq_
marlin
{
namespace
marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
@@ -59,23 +60,27 @@ __global__ void Marlin(
...
@@ -59,23 +60,27 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
size_k
,
bool
is_k_full
)
{
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
return
torch
::
empty
({
1
,
1
});
...
@@ -264,6 +269,114 @@ dequant_8bit<nv_bfloat16>(int q) {
...
@@ -264,6 +269,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return
frag_b
;
return
frag_b
;
}
}
// Zero-point dequantizers
// Primary template for the 4-bit zero-point dequantizer.
// The body contains only the scalar-type static assertion and no return
// statement; the real implementations are the explicit specializations for
// `half` and `nv_bfloat16` that follow.
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_4bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
// fp16 specialization of the 4-bit zero-point dequantizer.
//
// Two nibble positions are extracted from each 16-bit half of `q` and turned
// into fp16 values without any integer->float conversion instruction:
//   * bits [3:0] of each half are OR-ed into the mantissa of 0x6400 (fp16
//     1024.0), giving 1024 + n; subtracting 1024 leaves n.
//   * bits [7:4] of each half are OR-ed into the same pattern, giving
//     1024 + 16 * n; multiplying by 0x2c00 (fp16 1/16) and adding 0xd400
//     (fp16 -64.0) leaves n.
// No signed offset is removed here, i.e. the nibbles are treated as unsigned.
//
// q: packed 4-bit values; only the nibbles selected by the two masks are read.
// Returns a FragB whose element 0 holds the low-nibble pair and element 1 the
// next-nibble pair.
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
    int q) {
  using FragB = typename ScalarType<half>::FragB;

  // lop3 immediate. Per the upstream note this encodes `(a & b) | c`, so each
  // mask-and-merge below is meant to be a single LOP3 instruction.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  const int low_nibble_mask = 0x000f000f;
  const int high_nibble_mask = 0x00f000f0;
  const int fp16_bias_bits = 0x64006400;  // {1024.0h, 1024.0h}

  int low_bits = lop3<and_or_lut>(q, low_nibble_mask, fp16_bias_bits);
  int high_bits = lop3<and_or_lut>(q, high_nibble_mask, fp16_bias_bits);

  const int inv16_bits = 0x2c002c00;   // {1/16, 1/16} in fp16
  const int neg64_bits = 0xd400d400;   // {-64.0h, -64.0h}

  FragB result;
  // (1024 + n) - 1024
  result[0] = __hsub2(*reinterpret_cast<half2*>(&low_bits),
                      *reinterpret_cast<const half2*>(&fp16_bias_bits));
  // (1024 + 16 * n) * (1/16) - 64
  result[1] = __hfma2(*reinterpret_cast<half2*>(&high_bits),
                      *reinterpret_cast<const half2*>(&inv16_bits),
                      *reinterpret_cast<const half2*>(&neg64_bits));
  return result;
}
// bf16 specialization of the 4-bit zero-point dequantizer.
//
// For each 16-bit half of `q`, the low nibble is OR-ed into the mantissa of
// 0x4300 (bf16 128.0), producing 128 + n. `q` is then shifted right by four
// bits and the same extraction is repeated for the next nibble. Both results
// go through an fma with 1.0 (0x3F80) and -128.0 (0xC300), which leaves n.
// No signed offset is removed here, i.e. the nibbles are treated as unsigned.
//
// q: packed 4-bit values (taken by value; the local copy is shifted).
// Returns a FragB whose element 0 holds the first nibble pair and element 1
// the second.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
  using FragB = typename ScalarType<nv_bfloat16>::FragB;

  static constexpr uint32_t nibble_mask = 0x000f000f;
  static constexpr uint32_t bf16_bias_bits = 0x43004300;  // {128.0, 128.0}

  // lop3 immediate. Per the upstream note this encodes `(a & b) | c`, so each
  // mask-and-merge below is meant to be a single LOP3 instruction.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  int first_bits = lop3<and_or_lut>(q, nibble_mask, bf16_bias_bits);
  q >>= 4;  // bring the next nibble of each half under the mask
  int second_bits = lop3<and_or_lut>(q, nibble_mask, bf16_bias_bits);

  FragB result;

  static constexpr uint32_t one_bits = 0x3F803F80;     // {1.0, 1.0}
  static constexpr uint32_t neg128_bits = 0xC300C300;  // {-128.0, -128.0}

  // (128 + n) * 1.0 - 128
  result[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&first_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&one_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&neg128_bits));
  result[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&second_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&one_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&neg128_bits));
  return result;
}
// Primary template for the 8-bit zero-point dequantizer.
// The body contains only the scalar-type static assertion and no return
// statement; the real implementations are the explicit specializations for
// `half` and `nv_bfloat16` that follow.
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
// fp16 specialization of the 8-bit zero-point dequantizer.
//
// Each byte of `q` is combined with the byte 0x64 through the `prmt` helper,
// and 0x6400 (fp16 1024.0) is then subtracted from every resulting fp16 lane.
// NOTE(review): `prmt` is defined outside this view. The selectors 0x5250 and
// 0x5351 look like byte-permute selectors that pair bytes 0/2 and bytes 1/3 of
// `q` with 0x64 as the high byte (yielding 1024 + byte) - confirm against the
// helper.
//
// q: four packed 8-bit values.
// Returns a FragB; element 0 comes from the 0x5250 selector and element 1
// from the 0x5351 selector.
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
    int q) {
  using FragB = typename ScalarType<half>::FragB;

  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t first_pair = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t second_pair = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  // {1024.0h, 1024.0h}: the bias that the 0x64 high byte puts on each lane.
  static constexpr uint32_t fp16_bias_bits = 0x64006400;

  FragB result;
  result[0] = __hsub2(*reinterpret_cast<half2*>(&first_pair),
                      *reinterpret_cast<const half2*>(&fp16_bias_bits));
  result[1] = __hsub2(*reinterpret_cast<half2*>(&second_pair),
                      *reinterpret_cast<const half2*>(&fp16_bias_bits));
  return result;
}
// bf16 specialization of the 8-bit zero-point dequantizer.
//
// Each byte of `q` is placed into the low byte of 0x4B000000 (fp32 2^23 =
// 8388608.0) with __byte_perm, so the fp32 value becomes 8388608 + byte.
// Subtracting 8388608 leaves the byte as an fp32 number. The upper 16 bits of
// two such floats are then packed into one 32-bit word, which is the bf16
// representation of those values.
//
// Byte order: the four intermediates hold bytes 0, 2, 1, 3 of `q` (in that
// order), so output word 0 carries bytes {0, 2} and word 1 carries {1, 3}.
//
// q: four packed 8-bit values.
// Returns the FragB filled through a uint32_t view of its first two words.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
  typename ScalarType<nv_bfloat16>::FragB frag_b;

  float as_float[4];
  uint32_t* as_bits = reinterpret_cast<uint32_t*>(as_float);

  static constexpr uint32_t fp32_base = 0x4B000000;  // 2^23 as fp32

  // Selector 0x765N: bytes 3..1 from fp32_base, byte 0 = byte N of q.
  as_bits[0] = __byte_perm(q, fp32_base, 0x7650);
  as_bits[1] = __byte_perm(q, fp32_base, 0x7652);
  as_bits[2] = __byte_perm(q, fp32_base, 0x7651);
  as_bits[3] = __byte_perm(q, fp32_base, 0x7653);

  // Remove the 2^23 bias from every lane.
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    as_float[lane] -= 8388608.f;
  }

  // Selector 0x7632: upper halves of both inputs, first input in the low
  // 16 bits of the result.
  uint32_t* packed = reinterpret_cast<uint32_t*>(&frag_b);
  packed[0] = __byte_perm(as_bits[0], as_bits[1], 0x7632);
  packed[1] = __byte_perm(as_bits[2], as_bits[3], 0x7632);

  return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
// only for grouped quantization.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -277,6 +390,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
...
@@ -277,6 +390,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
}
// Subtract one zero-point value from a dequantized B fragment.
//
// frag_b:  fragment to adjust in place; both of its elements are updated.
// frag_zp: packed pair of zero-points, viewed here as an array of scalar_t.
// i:       which scalar of `frag_zp` to use.
//
// The chosen scalar is broadcast to a two-lane vector with
// ScalarType<scalar_t>::num2num2 and subtracted from frag_b[0] and frag_b[1].
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
                              int i) {
  using Vec2 = typename ScalarType<scalar_t>::scalar_t2;

  scalar_t* zp_lanes = reinterpret_cast<scalar_t*>(&frag_zp);
  Vec2 zp_broadcast = ScalarType<scalar_t>::num2num2(zp_lanes[i]);

  frag_b[0] = __hsub2(frag_b[0], zp_broadcast);
  frag_b[1] = __hsub2(frag_b[1], zp_broadcast);
}
// Same as above, but for act_order (each K is multiplied individually)
// Same as above, but for act_order (each K is multiplied individually)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
...
@@ -404,6 +528,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
...
@@ -404,6 +528,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const
int
stages
,
// number of stages for the async global->shared
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
// with a separate quantization scale
>
>
...
@@ -411,14 +536,18 @@ __global__ void Marlin(
...
@@ -411,14 +536,18 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// same size, which might involve multiple column "slices" (of width 16 *
...
@@ -437,6 +566,7 @@ __global__ void Marlin(
...
@@ -437,6 +566,7 @@ __global__ void Marlin(
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
constexpr
int
pack_factor
=
32
/
num_bits
;
constexpr
int
pack_factor
=
32
/
num_bits
;
...
@@ -471,6 +601,8 @@ __global__ void Marlin(
...
@@ -471,6 +601,8 @@ __global__ void Marlin(
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
if
(
slice_col_par
>=
n_tiles
)
{
...
@@ -478,6 +610,7 @@ __global__ void Marlin(
...
@@ -478,6 +610,7 @@ __global__ void Marlin(
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
}
// Compute all information about the current slice which is required for
// Compute all information about the current slice which is required for
...
@@ -508,6 +641,7 @@ __global__ void Marlin(
...
@@ -508,6 +641,7 @@ __global__ void Marlin(
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
locks
+=
n_tiles
;
slice_col
=
0
;
slice_col
=
0
;
par_id
++
;
}
}
};
};
init_slice
();
init_slice
();
...
@@ -566,6 +700,13 @@ __global__ void Marlin(
...
@@ -566,6 +700,13 @@ __global__ void Marlin(
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
@@ -605,6 +746,19 @@ __global__ void Marlin(
...
@@ -605,6 +746,19 @@ __global__ void Marlin(
int
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
// row-major in the latter case.
...
@@ -616,6 +770,18 @@ __global__ void Marlin(
...
@@ -616,6 +770,18 @@ __global__ void Marlin(
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
// Precompute which thread should not read memory in which iterations; this is
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
// when the batchsize is not a multiple of 16.
...
@@ -664,14 +830,17 @@ __global__ void Marlin(
...
@@ -664,14 +830,17 @@ __global__ void Marlin(
int4
*
sh_a
=
sh
;
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// Register storage for double buffer of shared memory reads.
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
// Zero accumulators.
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
auto
zero_accums
=
[
&
]()
{
...
@@ -777,6 +946,28 @@ __global__ void Marlin(
...
@@ -777,6 +946,28 @@ __global__ void Marlin(
}
}
}
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// Insert a fence even when we are winding down the pipeline to ensure that
...
@@ -784,6 +975,12 @@ __global__ void Marlin(
...
@@ -784,6 +975,12 @@ __global__ void Marlin(
cp_async_fence
();
cp_async_fence
();
};
};
auto
fetch_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// We only have `stages - 2` active fetches since we are double buffering
...
@@ -932,8 +1129,82 @@ __global__ void Marlin(
...
@@ -932,8 +1129,82 @@ __global__ void Marlin(
}
}
};
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
// This code does not handle group_blocks == 0, which signifies act_order.
// has_zp implies AWQ, which does not use act_order; the static_assert below
// enforces that combination at compile time.
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
if
constexpr
(
num_bits
==
4
)
{
int
zp_quant
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_shift
=
zp_quant
>>
8
;
frag_zp_0
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant
);
frag_zp_1
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant_shift
);
}
else
{
int
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
frag_zp_0
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_0
);
frag_zp_1
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_1
);
}
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
// dequantization and matmul operations.
#pragma unroll
#pragma unroll
...
@@ -944,16 +1215,32 @@ __global__ void Marlin(
...
@@ -944,16 +1215,32 @@ __global__ void Marlin(
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
if
constexpr
(
has_zp
)
{
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
frag_b0
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant_shift
);
}
else
{
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
}
else
{
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
if
constexpr
(
has_zp
)
{
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
frag_b0
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_1
);
}
else
{
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
}
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
}
}
// Apply scale to frag_b0
// Apply scale to frag_b0
...
@@ -967,6 +1254,11 @@ __global__ void Marlin(
...
@@ -967,6 +1254,11 @@ __global__ void Marlin(
}
}
}
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b1
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
...
@@ -1048,7 +1340,7 @@ __global__ void Marlin(
...
@@ -1048,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
// results in FP16 (but still reduce with FP32 compute).
...
@@ -1109,6 +1401,53 @@ __global__ void Marlin(
...
@@ -1109,6 +1401,53 @@ __global__ void Marlin(
}
}
};
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
int
par_offset
=
c_size
*
n_tiles
*
par_id
;
int
slice_offset
=
c_size
*
slice_col
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
par_offset
+
slice_offset
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
// in fragment layout.
...
@@ -1189,6 +1528,12 @@ __global__ void Marlin(
...
@@ -1189,6 +1528,12 @@ __global__ void Marlin(
}
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
}
if
constexpr
(
has_zp
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
}
...
@@ -1197,6 +1542,7 @@ __global__ void Marlin(
...
@@ -1197,6 +1542,7 @@ __global__ void Marlin(
init_same_group
(
0
);
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
};
...
@@ -1217,6 +1563,7 @@ __global__ void Marlin(
...
@@ -1217,6 +1563,7 @@ __global__ void Marlin(
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
slice_iters
>=
stages
);
...
@@ -1325,7 +1672,11 @@ __global__ void Marlin(
...
@@ -1325,7 +1672,11 @@ __global__ void Marlin(
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
}
if
(
last
)
// only the last block in a slice actually writes the result
if
(
last
)
// only the last block in a slice actually writes the result
...
@@ -1354,6 +1705,7 @@ __global__ void Marlin(
...
@@ -1354,6 +1705,7 @@ __global__ void Marlin(
}
else
{
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
start_pipes
();
start_pipes
();
...
@@ -1363,22 +1715,24 @@ __global__ void Marlin(
...
@@ -1363,22 +1715,24 @@ __global__ void Marlin(
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER &&
group_blocks == GROUP_BLOCKS &&
\
has_act_order == HAS_ACT_ORDER &&
has_zp == HAS_ZP &&
\
num_threads == NUM_THREADS) {
\
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {
\
cudaFuncSetAttribute( \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>,
\
HAS_ZP,
GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
HAS_ZP, GROUP_BLOCKS> \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
prob_k, locks); \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
}
}
typedef
struct
{
typedef
struct
{
...
@@ -1517,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
...
@@ -1517,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
true
;
return
true
;
}
}
// Round the M extent up to the granularity used for the reduce buffer.
//
// prob_m:  problem size along M.
// max_par: upper bound applied to the number of 64-row chunks in the large
//          case.
//
// Returns:
//   * prob_m <= 64: the smallest of 16, 32, 48, 64 that is >= prob_m
//     (16 for any prob_m <= 16, including non-positive values);
//   * otherwise: 64 * min(ceil(prob_m / 64), max_par).
int determine_reduce_max_m(int prob_m, int max_par) {
  constexpr int tile_m_size = 16;
  constexpr int max_tiles = 4;

  // Small problems: pick the first tile multiple that covers prob_m.
  for (int tiles = 1; tiles <= max_tiles; ++tiles) {
    const int covered_m = tile_m_size * tiles;
    if (prob_m <= covered_m) {
      return covered_m;
    }
  }

  // Large problems: whole 4-tile chunks, capped at max_par chunks.
  int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
  return tile_m_size * 4 * cur_par;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_act_order
,
bool
is_k_full
,
...
@@ -1548,44 +1923,77 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1548,44 +1923,77 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
// AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)
//
// Expands to 16 __CALL_IF invocations: the cross product of
//   * second argument  in {1, 2, 3, 4}
//   * seventh argument in {-1, 2, 4, 8}
// with the fifth argument fixed to `false` and the sixth fixed to `true`.
// NUM_BITS, N_BLOCKS, K_BLOCKS and NUM_THREADS are forwarded unchanged.
//
// NOTE(review): judging by the Marlin<...> template argument order used by the
// kernel-launch macro, those positions are presumably THREAD_M_BLOCKS (2nd),
// HAS_ACT_ORDER (5th), HAS_ZP (6th) and GROUP_BLOCKS (7th) -- confirm against
// the __CALL_IF definition.
// NOTE(review): the macro is used right after an `if (false) {}` and before a
// trailing `else`, so each __CALL_IF presumably expands to an `else if` arm.
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
thread_n
,
int
sms
,
int
max_par
)
{
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
int
sms
,
int
max_par
,
bool
use_fp32_reduce
)
{
"num_bits must be 4 or 8. Got = "
,
num_bits
);
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
// TODO: remove alias when we start supporting other 8bit types
int
num_bits
=
q_type
.
size_bits
();
int
tot_m
=
prob_m
;
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
...
@@ -1664,7 +2072,9 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1664,7 +2072,9 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
...
@@ -1701,28 +2111,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1701,28 +2111,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
}
// Define kernel configurations
if
(
false
)
{
if
(
false
)
{
}
}
CALL_IF
(
4
,
32
,
2
,
256
)
GPTQ_CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
8
,
256
)
CALL_IF
(
4
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
32
,
2
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
8
,
4
,
8
,
128
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
AWQ_CALL_IF
(
4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
8
,
4
,
8
,
128
)
else
{
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
,
num_groups
,
", group_size = "
,
group_size
,
", num_groups = "
+
str
(
num_groups
)
+
", thread_m_blocks = "
,
thread_m_blocks
,
", group_size = "
+
str
(
group_size
)
+
", thread_n_blocks = "
,
thread_n_blocks
,
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_k_blocks = "
,
thread_k_blocks
,
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", num_bits = "
,
num_bits
);
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
...
@@ -1730,17 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1730,17 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
}
}
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
size_k
,
bool
is_k_full
)
{
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
// Verify num_bits
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
bool
is_k_full
,
bool
has_zp
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
bool
use_fp32_reduce
)
{
int
pack_factor
=
32
/
num_bits
;
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
||
*
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
->
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
->
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
// Verify A
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
...
@@ -1749,16 +2175,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1749,16 +2175,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = "
,
size_k
);
", size_k = "
,
size_k
);
// Verify B
// Verify B
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_
marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_
marlin
::
tile_size
);
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_
marlin
::
tile_size
==
0
,
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
", actual_size_n = "
,
actual_size_n
);
...
@@ -1772,6 +2197,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1772,6 +2197,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_zeros
.
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
is_contiguous
(),
"b_zeros is not contiguous"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
...
@@ -1784,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1784,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// Alloc C tmp buffer that is going to be used for the global reduce
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
reduce_max_m
=
0
;
reduce_n
=
0
;
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
// auto -1)
int
thread_k
=
-
1
;
int
thread_k
=
-
1
;
...
@@ -1805,8 +2244,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1805,8 +2244,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
group_size
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
int
b_
rank
=
b_scales
.
sizes
().
size
();
int
rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_
rank
==
2
,
"b_scales rank = "
,
b_
rank
,
" is not 2"
);
TORCH_CHECK
(
rank
==
2
,
"b_scales rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
0
);
num_groups
=
b_scales
.
size
(
0
);
...
@@ -1832,34 +2271,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1832,34 +2271,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
}
}
// Verify b_zeros.
// Only checked when zero-points are in use (has_zp). The tensor must be
// rank 2 with shape [num_groups, size_n / pack_factor].
if (has_zp) {
  int rank = b_zeros.sizes().size();
  TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
  TORCH_CHECK(b_zeros.size(0) == num_groups,
              "b_zeros dim 0 = ", b_zeros.size(0),
              " is not num_groups = ", num_groups);
  // Report the offending b_zeros dimension in the message; the previous code
  // printed b_scales.size(1) here, which made the error misleading.
  TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
              "b_zeros dim 1 = ", b_zeros.size(1),
              " is not size_n / pack_factor = ", size_n / pack_factor);
}
// Verify workspace size
// Verify workspace size
TORCH_CHECK
(
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
gptq_
marlin
::
marlin_mm
_f16i4
<
half
>
(
marlin
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_n
,
sms
,
gptq_
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
gptq_
marlin
::
marlin_mm
_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BF
loat
16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
f
loat
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
is_k_full
,
num_groups
,
group_size
,
dev
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
gptq_marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
e7c1b7f3
#include "gptq_marlin.cuh"
#include "marlin.cuh"
namespace
gptq_marlin
{
static
constexpr
int
repack_stages
=
8
;
static
constexpr
int
repack_threads
=
256
;
static
constexpr
int
tile_k_size
=
tile_size
;
static
constexpr
int
tile_n_size
=
tile_k_size
*
4
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
int
size_k
,
int
size_n
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
...
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
#else
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
int
size_k
,
int
size_n
)
{
...
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
...
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
}
}
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM)
\
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {
\
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute(
\
cudaFuncSetAttribute( \
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads,
\
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads,
NUM_BITS,
\
NUM_BITS,
HAS_PERM>, \
HAS_PERM>,
\
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
\
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads, NUM_BITS, \
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads, NUM_BITS,
\
HAS_PERM>
\
HAS_PERM> \
<<<blocks,
gptq_
marlin::repack_threads, max_shared_mem, stream>>>( \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(
\
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
\
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_
marlin
::
tile_k_size
);
" is not divisible by tile_k_size = "
,
marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
%
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_
marlin
::
tile_n_size
);
" is not divisible by tile_n_size = "
,
marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto
options
=
torch
::
TensorOptions
()
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
Tensor
out
=
torch
::
empty
(
torch
::
empty
({
size_k
/
gptq_marlin
::
tile_size
,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
options
);
// Detect if there is act_order
// Detect if there is act_order
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment