Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1748 additions
and
822 deletions
+1748
-822
csrc/punica/punica_ops.h
csrc/punica/punica_ops.h
+0
-11
csrc/punica/torch_bindings.cpp
csrc/punica/torch_bindings.cpp
+0
-18
csrc/punica/type_convert.h
csrc/punica/type_convert.h
+0
-82
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+0
-2
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+1
-0
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+0
-23
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+148
-90
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+28
-496
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+340
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
+123
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
+139
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
+368
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+353
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+11
-19
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+1
-7
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+2
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+1
-1
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+3
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+198
-73
csrc/quantization/marlin/dense/common/base.h
csrc/quantization/marlin/dense/common/base.h
+32
-0
No files found.
csrc/punica/punica_ops.h
deleted
100644 → 0
View file @
6b16ea2e
#pragma once
#include <torch/all.h>
void
dispatch_bgmv
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
);
void
dispatch_bgmv_low_level
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
,
int64_t
h_in
,
int64_t
h_out
,
int64_t
y_offset
);
csrc/punica/torch_bindings.cpp
deleted
100644 → 0
View file @
6b16ea2e
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()"
);
m
.
impl
(
"dispatch_bgmv"
,
torch
::
kCUDA
,
&
dispatch_bgmv
);
m
.
def
(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()"
);
m
.
impl
(
"dispatch_bgmv_low_level"
,
torch
::
kCUDA
,
&
dispatch_bgmv_low_level
);
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
csrc/punica/type_convert.h
deleted
100644 → 0
View file @
6b16ea2e
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef
__half
nv_half
;
typedef
__hip_bfloat16
nv_bfloat16
;
typedef
__hip_bfloat162
nv_bfloat162
;
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat162
make_bfloat162
(
__hip_bfloat16
val
)
{
return
__hip_bfloat162
{
val
,
val
};
}
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat162
make_bfloat162
(
__hip_bfloat16
vall
,
__hip_bfloat16
valr
)
{
return
__hip_bfloat162
{
vall
,
valr
};
}
template
<
typename
T_src
,
typename
T_dst
>
__TYPE_CONVERT__HOST_DEVICE__
inline
T_dst
convert_type
(
T_src
val
)
{
return
static_cast
<
T_dst
>
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
float
convert_type
<
__half
,
float
>
(
__half
val
)
{
return
__half2float
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__half
convert_type
<
float
,
__half
>
(
float
val
)
{
return
__float2half
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
float
convert_type
<
__hip_bfloat16
,
float
>
(
__hip_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat16
convert_type
<
float
,
__hip_bfloat16
>
(
float
val
)
{
return
__float2bfloat16
(
val
);
}
template
<
typename
T
>
__TYPE_CONVERT__HOST_DEVICE__
inline
T
vllm_add
(
T
a
,
T
b
)
{
return
a
+
b
;
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__half
vllm_add
<
__half
>
(
__half
a
,
__half
b
)
{
return
__hadd
(
a
,
b
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat16
vllm_add
<
__hip_bfloat16
>
(
__hip_bfloat16
a
,
__hip_bfloat16
b
)
{
return
__hadd
(
a
,
b
);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
csrc/quantization/aqlm/gemm_kernels.cu
View file @
e661d594
...
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
__syncthreads
();
float
res
=
0
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
...
...
csrc/quantization/awq/dequantize.cuh
View file @
e661d594
...
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return
result
;
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// namespace awq
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
e661d594
...
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace
vllm
{
namespace
awq
{
// Pack two half values.
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
...
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
scaling_factors_shared
[
N
];
__shared__
half
zeros_shared
[
N
];
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
blockIdx_x
=
0
;
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
...
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
...
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
...
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
static
constexpr
uint32_t
ZERO
=
0x0
;
half
B_shared
[
32
*
(
128
+
8
)];
half
*
B_shared_ptr2
=
B_shared
;
half
B_shared_warp
[
32
];
int
OC
=
512
;
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
View file @
e661d594
...
...
@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
...
...
@@ -73,14 +71,12 @@ template<
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
static_assert
(
Stages
==
0
,
"Row broadcast doesn't support smem usage"
);
static_assert
(
is_static_v
<
decltype
(
take
<
0
,
2
>
(
StrideMNL
{}))
>
);
// batch stride can be dynamic or static
static_assert
(
take
<
0
,
2
>
(
StrideMNL
{})
==
Stride
<
_0
,
_1
>
{});
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
struct
SharedStorage
{
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
>
smem
;
};
// This struct has been modified to have a bool indicating that ptr_row is a
...
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
,
smem
_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem
_row
.
data
()))
{
}
:
params
(
params
)
,
smem
(
const_cast
<
Element
*>
(
shared_storage
.
smem
.
data
()))
{
}
Params
params
;
Element
*
smem
_row
;
Element
*
smem
=
nullptr
;
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
tru
e
;
return
fals
e
;
}
CUTLASS_DEVICE
bool
...
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
template
<
int
EpiTiles
,
class
GTensor
,
class
STensor
>
struct
ProducerLoadCallbacks
:
EmptyProducerLoadCallbacks
{
CUTLASS_DEVICE
ProducerLoadCallbacks
(
GTensor
&&
gRow
,
STensor
&&
sRow
,
Params
const
&
params
)
:
gRow
(
cute
::
forward
<
GTensor
>
(
gRow
)),
sRow
(
cute
::
forward
<
STensor
>
(
sRow
)),
params
(
params
)
{}
GTensor
gRow
;
// (CTA_M,CTA_N)
STensor
sRow
;
// (CTA_M,CTA_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
(
uint64_t
*
full_mbarrier_ptr
,
int
load_iteration
,
bool
issue_tma_load
)
{
if
(
!
params
.
row_broadcast
)
{
return
;
}
if
(
issue_tma_load
)
{
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr
uint32_t
copy_bytes
=
size
<
1
>
(
CtaTileShapeMNK
{})
*
sizeof_bits_v
<
Element
>
/
8
;
cutlass
::
arch
::
ClusterTransactionBarrier
::
expect_transaction
(
full_mbarrier_ptr
,
copy_bytes
);
// Issue the TMA bulk copy
auto
bulk_copy
=
Copy_Atom
<
SM90_BULK_COPY_AUTO
,
Element
>
{}.
with
(
*
full_mbarrier_ptr
);
// Filter so we don't issue redundant copies over stride-0 modes
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
(
bulk_copy
,
filter
(
gRow
),
filter
(
sRow
(
_
,
_
,
bcast_pipe_index
)));
}
}
};
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
return
EmptyProducerLoadCallbacks
{};
}
template
<
int
EpiTiles
,
class
R
Tensor
,
class
STensor
>
template
<
class
GS_GTensor
,
class
GS_STensor
,
class
GS_CTensor
,
class
Tiled_G2S
,
class
SR_S
Tensor
,
class
S
R_R
Tensor
,
class
CTensor
,
class
ThrResidue
,
class
ThrNum
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
params
(
params
)
{}
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
ConsumerStoreCallbacks
(
GS_GTensor
tGS_gRow_
,
GS_STensor
tGS_sRow_
,
GS_CTensor
tGS_cRow_
,
Tiled_G2S
tiled_g2s_
,
SR_STensor
tSR_sRow_
,
SR_RTensor
tSR_rRow_
,
CTensor
tCcRow_
,
ThrResidue
residue_tCcRow_
,
ThrNum
thr_num_
,
Params
const
&
params_
)
:
tGS_gRow
(
tGS_gRow_
)
,
tGS_sRow
(
tGS_sRow_
)
,
tGS_cRow
(
tGS_cRow_
)
,
tiled_G2S
(
tiled_g2s_
)
,
tSR_sRow
(
tSR_sRow_
)
,
tSR_rRow
(
tSR_rRow_
)
,
tCcRow
(
tCcRow_
)
,
residue_tCcRow
(
residue_tCcRow_
)
,
params
(
params_
)
{}
GS_GTensor
tGS_gRow
;
// (CPY,CPY_M,CPY_N)
GS_STensor
tGS_sRow
;
// (CPY,CPY_M,CPY_N)
GS_CTensor
tGS_cRow
;
// (CPY,CPY_M,CPY_N)
Tiled_G2S
tiled_G2S
;
SR_STensor
tSR_sRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor
tSR_rRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor
tCcRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue
residue_tCcRow
;
// (m, n)
ThrNum
thr_num
;
Params
const
&
params
;
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
begin
(
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
t
C
rRow
,
*
(
params
.
ptr_row
));
fill
(
t
SR_
rRow
,
*
(
params
.
ptr_row
));
return
;
}
auto
synchronize
=
[
&
]
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
thr_num
,
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
);
};
Tensor
tGS_gRow_flt
=
filter_zeros
(
tGS_gRow
);
Tensor
tGS_sRow_flt
=
filter_zeros
(
tGS_sRow
);
Tensor
tGS_cRow_flt
=
make_tensor
(
tGS_cRow
.
data
(),
make_layout
(
tGS_gRow_flt
.
shape
(),
tGS_cRow
.
stride
()));
for
(
int
i
=
0
;
i
<
size
(
tGS_gRow_flt
);
++
i
)
{
if
(
get
<
1
>
(
tGS_cRow_flt
(
i
))
>=
size
<
1
>
(
CtaTileShapeMNK
{}))
{
continue
;
// OOB of SMEM,
}
if
(
elem_less
(
tGS_cRow_flt
(
i
),
make_coord
(
get
<
0
>
(
residue_tCcRow
),
get
<
1
>
(
residue_tCcRow
))))
{
tGS_sRow_flt
(
i
)
=
tGS_gRow_flt
(
i
);
}
else
{
tGS_sRow_flt
(
i
)
=
Element
(
0
);
// Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize
();
}
CUTLASS_DEVICE
void
begin_loop
(
int
epi_m
,
int
epi_n
)
{
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
)
);
if
(
!
params
.
row_broadcast
)
return
;
// Do not issue LDS when row is scalar
Tensor
tSR_sRow_flt
=
filter_zeros
(
tSR_sRow
(
_
,
_
,
_
,
epi_m
,
epi_n
));
Tensor
tSR_rRow_flt
=
filter_zeros
(
tSR_rRow
)
;
copy
(
tSR_sRow_flt
,
tSR_rRow_flt
);
}
}
...
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
t
C
rRow
(
epi_v
*
FragmentSize
+
i
);
frg_row
[
i
]
=
t
SR_
rRow
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_row
;
...
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
using
ThreadCount
=
decltype
(
size
(
args
.
tiled_copy
));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
(
_
,
_
,
l
),
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
));
// (CTA_M, CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem
),
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})),
make_shape
(
_0
{},
_1
{}));
// (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto
tiled_g2s
=
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
Layout
<
Shape
<
_1
,
ThreadCount
>
,
Stride
<
_0
,
_1
>>
{},
Layout
<
_1
>
{});
auto
thr_g2s
=
tiled_g2s
.
get_slice
(
args
.
thread_idx
);
Tensor
tGS_gRow
=
thr_g2s
.
partition_S
(
gRow
);
Tensor
tGS_sRow
=
thr_g2s
.
partition_D
(
sRow
);
//// G2S: Coord
auto
cRow
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tGS_cRow
=
thr_g2s
.
partition_S
(
cRow
);
//// S2R: Smem to Reg
Tensor
tSR_sRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tSR_rRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tSR_sRow
));
// (CPY,CPY_M,CPY_N)
return
ConsumerStoreCallbacks
<
decltype
(
tGS_gRow
),
decltype
(
tGS_sRow
),
decltype
(
tGS_cRow
),
decltype
(
tiled_g2s
),
decltype
(
tSR_sRow
),
decltype
(
tSR_rRow
),
decltype
(
args
.
tCcD
),
decltype
(
args
.
residue_cD
),
ThreadCount
>
(
tGS_gRow
,
tGS_sRow
,
tGS_cRow
,
tiled_g2s
,
tSR_sRow
,
tSR_rRow
,
args
.
tCcD
,
args
.
residue_cD
,
ThreadCount
{},
params
);
}
};
...
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
>
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
params
(
params
)
{}
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
int
m
;
CUTLASS_DEVICE
void
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
...
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_
aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
copy_
if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
...
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
params
);
}
};
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
e661d594
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
}
};
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
cutlass
::
arch
::
OpMultiplyAdd
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
get
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
FallbackGemm
,
typename
...
EpilogueArgs
>
void
fallback_cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static
const
int
max_shared_mem_per_block_opt_in
=
get_cuda_max_shared_memory_per_block_opt_in
(
0
);
size_t
const
gemm_shared_mem_size
=
sizeof
(
typename
Gemm
::
KernelType
::
SharedStorage
);
size_t
const
fallback_gemm_shared_mem_size
=
sizeof
(
typename
FallbackGemm
::
KernelType
::
SharedStorage
);
if
(
gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
)
{
return
cutlass_gemm_caller
<
Gemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
TORCH_CHECK
(
fallback_gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
);
return
cutlass_gemm_caller
<
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm80_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128BigN
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128SmallN
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM64
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM16
=
typename
sm80_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using
FallbackGemm
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM16
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
small_n
=
n
<
8192
;
if
(
small_n
)
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128SmallN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128BigN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
// M in (128, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
...
...
@@ -473,20 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
half_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -501,11 +42,11 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm75_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -518,11 +59,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -537,11 +79,11 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -550,23 +92,17 @@ template <template <typename, typename> typename Epilogue,
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
half_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
...
...
@@ -574,17 +110,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -600,10 +132,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm89_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
0 → 100644
View file @
e661d594
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
vllm
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
}
};
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
,
typename
FP8MathOperator
=
cutlass
::
arch
::
OpMultiplyAdd
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
FP8MathOperator
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
FallbackGemm
,
typename
...
EpilogueArgs
>
inline
void
fallback_cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static
const
int
max_shared_mem_per_block_opt_in
=
get_cuda_max_shared_memory_per_block_opt_in
(
0
);
size_t
const
gemm_shared_mem_size
=
sizeof
(
typename
Gemm
::
KernelType
::
SharedStorage
);
size_t
const
fallback_gemm_shared_mem_size
=
sizeof
(
typename
FallbackGemm
::
KernelType
::
SharedStorage
);
if
(
gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
)
{
return
cutlass_gemm_caller
<
Gemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
TORCH_CHECK
(
fallback_gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
);
return
cutlass_gemm_caller
<
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_default
{
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M256
{
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M64
{
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M32
{
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm75_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm75_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM256
=
typename
sm75_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128
=
Cutlass2xGemmDefault
;
using
Cutlass2xGemmM64
=
typename
sm75_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm75_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using
FallbackGemm
=
Cutlass2xGemmDefault
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// M in [1, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM256
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm80_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128BigN
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128SmallN
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM64
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM16
=
typename
sm80_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using
FallbackGemm
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM16
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
small_n
=
n
<
8192
;
if
(
small_n
)
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128SmallN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128BigN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
// M in (128, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_fp8_fallback_gemm
{
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
;
};
struct
sm89_fp8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M128
{
// M in (64, 128]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8196
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
static
const
int32_t
MainLoopStages
=
5
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
24576
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_fp8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_fp8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_fp8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_fp8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_fp8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_fp8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_int8_fallback_gemm
{
// Shared mem requirement : 61440
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
static
int32_t
const
MainLoopStages
=
5
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
struct
sm89_int8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M128
{
// M in (64, 128]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_int8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_int8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_int8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_int8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_int8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_int8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
e661d594
...
...
@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
...
...
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleBDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
0
/*
Stages
*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
...
...
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
BiasDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
ElementD
>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
BiasDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
0
/*
Stages
*/
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
public:
...
...
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
...
...
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
e661d594
...
...
@@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
89
)
{
// CUTLASS Kernels have not been tuned for Ada Lovelace systems
// and are slower than torch.mm. Return false unconditionally in this case.
return
false
;
// Once the CUTLASS kernels have been optimized for Lovelace systems,
// use the following check:
// return CUDA_VERSION >= 12040;
return
CUDA_VERSION
>=
12040
;
}
#endif
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
e661d594
...
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/fp8/common.cu
View file @
e661d594
...
...
@@ -48,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
int
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
...
...
csrc/quantization/fp8/nvidia/quant_utils.cuh
View file @
e661d594
...
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
// float -> fp8
...
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
e661d594
...
...
@@ -21,6 +21,7 @@
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
...
...
@@ -59,24 +60,27 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
bool
is_k_full
,
bool
has_zp
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
...
...
@@ -532,16 +536,18 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
...
...
@@ -595,6 +601,8 @@ __global__ void Marlin(
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
...
...
@@ -602,6 +610,7 @@ __global__ void Marlin(
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
// Compute all information about the current slice which is required for
...
...
@@ -632,6 +641,7 @@ __global__ void Marlin(
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
par_id
++
;
}
};
init_slice
();
...
...
@@ -1120,44 +1130,53 @@ __global__ void Marlin(
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
if
constexpr
(
!
has_zp
)
{
return
;
}
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
has_zp
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
sh_zp_stage
+
=
cur_group_id
*
zp_sh_strid
e
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pip
e
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
};
...
...
@@ -1321,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
...
...
@@ -1382,6 +1401,53 @@ __global__ void Marlin(
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
int
par_offset
=
c_size
*
n_tiles
*
par_id
;
int
slice_offset
=
c_size
*
slice_col
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
par_offset
+
slice_offset
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
...
...
@@ -1606,7 +1672,11 @@ __global__ void Marlin(
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
...
...
@@ -1661,8 +1731,8 @@ __global__ void Marlin(
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr,
num_groups,
\
prob_m, prob_n, prob_k, locks
);
\
A_ptr, B_ptr, C_ptr,
C_tmp_ptr,
s_ptr, zp_ptr, g_idx_ptr, \
num_groups,
prob_m, prob_n, prob_k, locks
, use_fp32_reduce);
\
}
typedef
struct
{
...
...
@@ -1801,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
true
;
}
int
determine_reduce_max_m
(
int
prob_m
,
int
max_par
)
{
constexpr
int
tile_m_size
=
16
;
if
(
prob_m
<=
tile_m_size
)
{
return
tile_m_size
;
}
else
if
(
prob_m
<=
tile_m_size
*
2
)
{
return
tile_m_size
*
2
;
}
else
if
(
prob_m
<=
tile_m_size
*
3
)
{
return
tile_m_size
*
3
;
}
else
if
(
prob_m
<=
tile_m_size
*
4
)
{
return
tile_m_size
*
4
;
}
else
{
int
cur_par
=
min
(
div_ceil
(
prob_m
,
tile_m_size
*
4
),
max_par
);
return
tile_m_size
*
4
*
cur_par
;
}
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
...
...
@@ -1880,18 +1971,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
use_fp32_reduce
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
// TODO: remove alias when we start supporting other 8bit types
int
num_bits
=
q_type
.
size_bits
();
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
...
...
@@ -1970,6 +2072,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
...
...
@@ -2042,18 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
||
*
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
->
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
->
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
...
...
@@ -2099,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// Alloc C tmp buffer that is going to be used for the global reduce
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
reduce_max_m
=
0
;
reduce_n
=
0
;
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
...
...
@@ -2169,22 +2293,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
marlin
::
marlin_mm
_f16i4
<
half
>
(
marlin
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm
_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BF
loat
16
>
(),
b_
zero
s
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
f
loat
>
(),
b_
scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
...
...
csrc/quantization/marlin/dense/common/base.h
0 → 100644
View file @
e661d594
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template
<
typename
T
,
int
n
>
struct
Vec
{
T
elems
[
n
];
__device__
T
&
operator
[](
int
i
)
{
return
elems
[
i
];
}
};
Prev
1
2
3
4
5
6
7
8
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment