Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1748 additions
and
822 deletions
+1748
-822
csrc/punica/punica_ops.h
csrc/punica/punica_ops.h
+0
-11
csrc/punica/torch_bindings.cpp
csrc/punica/torch_bindings.cpp
+0
-18
csrc/punica/type_convert.h
csrc/punica/type_convert.h
+0
-82
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+0
-2
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+1
-0
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+0
-23
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+148
-90
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+28
-496
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+340
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
+123
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
...quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
+139
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
+368
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+353
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+11
-19
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+1
-7
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+2
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+1
-1
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+3
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+198
-73
csrc/quantization/marlin/dense/common/base.h
csrc/quantization/marlin/dense/common/base.h
+32
-0
No files found.
csrc/punica/punica_ops.h
deleted
100644 → 0
View file @
6b16ea2e
#pragma once
#include <torch/all.h>
void
dispatch_bgmv
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
);
void
dispatch_bgmv_low_level
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
,
int64_t
h_in
,
int64_t
h_out
,
int64_t
y_offset
);
csrc/punica/torch_bindings.cpp
deleted
100644 → 0
View file @
6b16ea2e
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()"
);
m
.
impl
(
"dispatch_bgmv"
,
torch
::
kCUDA
,
&
dispatch_bgmv
);
m
.
def
(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()"
);
m
.
impl
(
"dispatch_bgmv_low_level"
,
torch
::
kCUDA
,
&
dispatch_bgmv_low_level
);
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
csrc/punica/type_convert.h
deleted
100644 → 0
View file @
6b16ea2e
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef
__half
nv_half
;
typedef
__hip_bfloat16
nv_bfloat16
;
typedef
__hip_bfloat162
nv_bfloat162
;
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat162
make_bfloat162
(
__hip_bfloat16
val
)
{
return
__hip_bfloat162
{
val
,
val
};
}
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat162
make_bfloat162
(
__hip_bfloat16
vall
,
__hip_bfloat16
valr
)
{
return
__hip_bfloat162
{
vall
,
valr
};
}
template
<
typename
T_src
,
typename
T_dst
>
__TYPE_CONVERT__HOST_DEVICE__
inline
T_dst
convert_type
(
T_src
val
)
{
return
static_cast
<
T_dst
>
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
float
convert_type
<
__half
,
float
>
(
__half
val
)
{
return
__half2float
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__half
convert_type
<
float
,
__half
>
(
float
val
)
{
return
__float2half
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
float
convert_type
<
__hip_bfloat16
,
float
>
(
__hip_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat16
convert_type
<
float
,
__hip_bfloat16
>
(
float
val
)
{
return
__float2bfloat16
(
val
);
}
template
<
typename
T
>
__TYPE_CONVERT__HOST_DEVICE__
inline
T
vllm_add
(
T
a
,
T
b
)
{
return
a
+
b
;
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__half
vllm_add
<
__half
>
(
__half
a
,
__half
b
)
{
return
__hadd
(
a
,
b
);
}
template
<
>
__TYPE_CONVERT__HOST_DEVICE__
inline
__hip_bfloat16
vllm_add
<
__hip_bfloat16
>
(
__hip_bfloat16
a
,
__hip_bfloat16
b
)
{
return
__hadd
(
a
,
b
);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
csrc/quantization/aqlm/gemm_kernels.cu
View file @
e661d594
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
}
__syncthreads
();
__syncthreads
();
float
res
=
0
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
while
(
iters
--
)
{
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
...
...
csrc/quantization/awq/dequantize.cuh
View file @
e661d594
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
...
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return
result
;
return
result
;
#endif
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
}
// namespace awq
}
// namespace awq
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
e661d594
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
...
@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace
vllm
{
namespace
vllm
{
namespace
awq
{
namespace
awq
{
// Pack two half values.
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
template
<
int
N
>
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
...
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
A_shared
[
16
*
(
32
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
B_shared
[
32
*
(
N
+
8
)];
__shared__
half
scaling_factors_shared
[
N
];
__shared__
half
zeros_shared
[
N
];
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
j_factors1
=
((
OC
+
N
-
1
)
/
N
);
int
blockIdx_x
=
0
;
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_y
=
blockIdx
.
x
%
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
int
blockIdx_z
=
blockIdx
.
x
/
((
M
+
16
-
1
)
/
16
*
j_factors1
);
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
...
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride_warp
=
32
*
8
/
32
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
...
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t
B_loaded
=
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
// q * scale - zero * scale.
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
...
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__
void
__launch_bounds__
(
64
)
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
static
constexpr
uint32_t
ZERO
=
0x0
;
static
constexpr
uint32_t
ZERO
=
0x0
;
half
B_shared
[
32
*
(
128
+
8
)];
half
B_shared
[
32
*
(
128
+
8
)];
half
*
B_shared_ptr2
=
B_shared
;
half
*
B_shared_ptr2
=
B_shared
;
half
B_shared_warp
[
32
];
int
OC
=
512
;
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
N
=
blockDim
.
x
*
gridDim
.
x
;
// 2
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
col
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
View file @
e661d594
...
@@ -64,8 +64,6 @@ using namespace detail;
...
@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast
// Row vector broadcast
template
<
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
int
Stages
,
class
CtaTileShapeMNK
,
class
CtaTileShapeMNK
,
class
Element
,
class
Element
,
...
@@ -73,14 +71,12 @@ template<
...
@@ -73,14 +71,12 @@ template<
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
>
struct
Sm90RowOrScalarBroadcast
{
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
Stages
==
0
,
"Row broadcast doesn't support smem usage"
);
static_assert
(
static_assert
(
is_static_v
<
decltype
(
take
<
0
,
2
>
(
StrideMNL
{}))
>
);
// batch stride can be dynamic or static
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
static_assert
(
take
<
0
,
2
>
(
StrideMNL
{})
==
Stride
<
_0
,
_1
>
{});
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
struct
SharedStorage
{
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
>
smem
;
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
};
};
// This struct has been modified to have a bool indicating that ptr_row is a
// This struct has been modified to have a bool indicating that ptr_row is a
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return
args
;
return
args
;
}
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
template
<
class
ProblemShape
>
static
size_t
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
,
:
params
(
params
)
smem
_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem
_row
.
data
()))
{
}
,
smem
(
const_cast
<
Element
*>
(
shared_storage
.
smem
.
data
()))
{
}
Params
params
;
Params
params
;
Element
*
smem
_row
;
Element
*
smem
=
nullptr
;
CUTLASS_DEVICE
bool
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
is_producer_load_needed
()
const
{
return
tru
e
;
return
fals
e
;
}
}
CUTLASS_DEVICE
bool
CUTLASS_DEVICE
bool
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
}
template
<
int
EpiTiles
,
class
GTensor
,
class
STensor
>
struct
ProducerLoadCallbacks
:
EmptyProducerLoadCallbacks
{
CUTLASS_DEVICE
ProducerLoadCallbacks
(
GTensor
&&
gRow
,
STensor
&&
sRow
,
Params
const
&
params
)
:
gRow
(
cute
::
forward
<
GTensor
>
(
gRow
)),
sRow
(
cute
::
forward
<
STensor
>
(
sRow
)),
params
(
params
)
{}
GTensor
gRow
;
// (CTA_M,CTA_N)
STensor
sRow
;
// (CTA_M,CTA_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
(
uint64_t
*
full_mbarrier_ptr
,
int
load_iteration
,
bool
issue_tma_load
)
{
if
(
!
params
.
row_broadcast
)
{
return
;
}
if
(
issue_tma_load
)
{
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr
uint32_t
copy_bytes
=
size
<
1
>
(
CtaTileShapeMNK
{})
*
sizeof_bits_v
<
Element
>
/
8
;
cutlass
::
arch
::
ClusterTransactionBarrier
::
expect_transaction
(
full_mbarrier_ptr
,
copy_bytes
);
// Issue the TMA bulk copy
auto
bulk_copy
=
Copy_Atom
<
SM90_BULK_COPY_AUTO
,
Element
>
{}.
with
(
*
full_mbarrier_ptr
);
// Filter so we don't issue redundant copies over stride-0 modes
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
(
bulk_copy
,
filter
(
gRow
),
filter
(
sRow
(
_
,
_
,
bcast_pipe_index
)));
}
}
};
template
<
class
...
Args
>
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
}
}
template
<
int
EpiTiles
,
class
R
Tensor
,
class
STensor
>
template
<
class
GS_GTensor
,
class
GS_STensor
,
class
GS_CTensor
,
class
Tiled_G2S
,
class
SR_S
Tensor
,
class
S
R_R
Tensor
,
class
CTensor
,
class
ThrResidue
,
class
ThrNum
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
ConsumerStoreCallbacks
(
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
GS_GTensor
tGS_gRow_
,
GS_STensor
tGS_sRow_
,
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
GS_CTensor
tGS_cRow_
,
Tiled_G2S
tiled_g2s_
,
params
(
params
)
{}
SR_STensor
tSR_sRow_
,
SR_RTensor
tSR_rRow_
,
CTensor
tCcRow_
,
ThrResidue
residue_tCcRow_
,
ThrNum
thr_num_
,
Params
const
&
params_
)
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
:
tGS_gRow
(
tGS_gRow_
)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
,
tGS_sRow
(
tGS_sRow_
)
,
tGS_cRow
(
tGS_cRow_
)
,
tiled_G2S
(
tiled_g2s_
)
,
tSR_sRow
(
tSR_sRow_
)
,
tSR_rRow
(
tSR_rRow_
)
,
tCcRow
(
tCcRow_
)
,
residue_tCcRow
(
residue_tCcRow_
)
,
params
(
params_
)
{}
GS_GTensor
tGS_gRow
;
// (CPY,CPY_M,CPY_N)
GS_STensor
tGS_sRow
;
// (CPY,CPY_M,CPY_N)
GS_CTensor
tGS_cRow
;
// (CPY,CPY_M,CPY_N)
Tiled_G2S
tiled_G2S
;
SR_STensor
tSR_sRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor
tSR_rRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor
tCcRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue
residue_tCcRow
;
// (m, n)
ThrNum
thr_num
;
Params
const
&
params
;
Params
const
&
params
;
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
begin
(
)
{
if
(
!
params
.
row_broadcast
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
t
C
rRow
,
*
(
params
.
ptr_row
));
fill
(
t
SR_
rRow
,
*
(
params
.
ptr_row
));
return
;
return
;
}
}
auto
synchronize
=
[
&
]
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
thr_num
,
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
);
};
Tensor
tGS_gRow_flt
=
filter_zeros
(
tGS_gRow
);
Tensor
tGS_sRow_flt
=
filter_zeros
(
tGS_sRow
);
Tensor
tGS_cRow_flt
=
make_tensor
(
tGS_cRow
.
data
(),
make_layout
(
tGS_gRow_flt
.
shape
(),
tGS_cRow
.
stride
()));
for
(
int
i
=
0
;
i
<
size
(
tGS_gRow_flt
);
++
i
)
{
if
(
get
<
1
>
(
tGS_cRow_flt
(
i
))
>=
size
<
1
>
(
CtaTileShapeMNK
{}))
{
continue
;
// OOB of SMEM,
}
if
(
elem_less
(
tGS_cRow_flt
(
i
),
make_coord
(
get
<
0
>
(
residue_tCcRow
),
get
<
1
>
(
residue_tCcRow
))))
{
tGS_sRow_flt
(
i
)
=
tGS_gRow_flt
(
i
);
}
else
{
tGS_sRow_flt
(
i
)
=
Element
(
0
);
// Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize
();
}
CUTLASS_DEVICE
void
begin_loop
(
int
epi_m
,
int
epi_n
)
{
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
if
(
!
params
.
row_broadcast
)
return
;
// Do not issue LDS when row is scalar
// (only works if 0-strides are in same location, which is by construction)
Tensor
tSR_sRow_flt
=
filter_zeros
(
tSR_sRow
(
_
,
_
,
_
,
epi_m
,
epi_n
));
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
Tensor
tSR_rRow_flt
=
filter_zeros
(
tSR_rRow
)
;
copy
_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
)
);
copy
(
tSR_sRow_flt
,
tSR_rRow_flt
);
}
}
}
}
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
t
C
rRow
(
epi_v
*
FragmentSize
+
i
);
frg_row
[
i
]
=
t
SR_
rRow
(
epi_v
*
FragmentSize
+
i
);
}
}
return
frg_row
;
return
frg_row
;
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
...
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
>
>
CUTLASS_DEVICE
auto
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
using
ThreadCount
=
decltype
(
size
(
args
.
tiled_copy
));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
Tensor
gRow
=
local_tile
(
mRow
(
_
,
_
,
l
),
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
));
// (CTA_M, CTA_N)
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem
),
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})),
make_shape
(
_0
{},
_1
{}));
// (CTA_M, CTA_N)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
//// G2S: Gmem to Smem
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
auto
tiled_g2s
=
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
Layout
<
Shape
<
_1
,
ThreadCount
>
,
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
Stride
<
_0
,
_1
>>
{},
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
Layout
<
_1
>
{});
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
auto
thr_g2s
=
tiled_g2s
.
get_slice
(
args
.
thread_idx
);
Tensor
tGS_gRow
=
thr_g2s
.
partition_S
(
gRow
);
Tensor
tGS_sRow
=
thr_g2s
.
partition_D
(
sRow
);
//// G2S: Coord
auto
cRow
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tGS_cRow
=
thr_g2s
.
partition_S
(
cRow
);
//// S2R: Smem to Reg
Tensor
tSR_sRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tSR_rRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tSR_sRow
));
// (CPY,CPY_M,CPY_N)
return
ConsumerStoreCallbacks
<
decltype
(
tGS_gRow
),
decltype
(
tGS_sRow
),
decltype
(
tGS_cRow
),
decltype
(
tiled_g2s
),
decltype
(
tSR_sRow
),
decltype
(
tSR_rRow
),
decltype
(
args
.
tCcD
),
decltype
(
args
.
residue_cD
),
ThreadCount
>
(
tGS_gRow
,
tGS_sRow
,
tGS_cRow
,
tiled_g2s
,
tSR_sRow
,
tSR_rRow
,
args
.
tCcD
,
args
.
residue_cD
,
ThreadCount
{},
params
);
}
}
};
};
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return
args
;
return
args
;
}
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
template
<
class
ProblemShape
>
static
size_t
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return
EmptyProducerLoadCallbacks
{};
return
EmptyProducerLoadCallbacks
{};
}
}
template
<
class
GTensor
,
class
RTensor
>
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
ConsumerStoreCallbacks
(
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
GTensor
&&
tCgCol
,
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
RTensor
&&
tCrCol
,
params
(
params
)
{}
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
Params
const
&
params
;
int
m
;
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
begin
()
{
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
return
;
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
// (only works if 0-strides are in same location, which is by construction)
copy_
aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
copy_
if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
...
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
// Generate an identity tensor matching the shape of the global tensor and
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
params
);
}
}
};
};
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
e661d594
#include <stddef.h>
#include <stddef.h>
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
}
};
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
cutlass
::
arch
::
OpMultiplyAdd
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
get
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
FallbackGemm
,
typename
...
EpilogueArgs
>
void
fallback_cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static
const
int
max_shared_mem_per_block_opt_in
=
get_cuda_max_shared_memory_per_block_opt_in
(
0
);
size_t
const
gemm_shared_mem_size
=
sizeof
(
typename
Gemm
::
KernelType
::
SharedStorage
);
size_t
const
fallback_gemm_shared_mem_size
=
sizeof
(
typename
FallbackGemm
::
KernelType
::
SharedStorage
);
if
(
gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
)
{
return
cutlass_gemm_caller
<
Gemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
TORCH_CHECK
(
fallback_gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
);
return
cutlass_gemm_caller
<
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm80_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128BigN
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128SmallN
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM64
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM16
=
typename
sm80_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using
FallbackGemm
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM16
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
small_n
=
n
<
8192
;
if
(
small_n
)
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128SmallN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128BigN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
// M in (128, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
typename
...
EpilogueArgs
>
...
@@ -473,20 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -473,20 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
half_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -501,11 +42,11 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
...
@@ -501,11 +42,11 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm75_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -518,11 +59,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -518,11 +59,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -537,11 +79,11 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
...
@@ -537,11 +79,11 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -550,23 +92,17 @@ template <template <typename, typename> typename Epilogue,
...
@@ -550,23 +92,17 @@ template <template <typename, typename> typename Epilogue,
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
EpilogueArgs
&&
...
epilogue_args
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
cutlass_2x_gemm
<
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
else
{
}
else
{
...
@@ -574,17 +110,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -574,17 +110,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_caller
<
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_caller
<
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
half_t
,
Epilogue
>
(
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -600,10 +132,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
...
@@ -600,10 +132,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm89_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogue
>
(
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
0 → 100644
View file @
e661d594
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
vllm
{
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
}
};
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
BiasArgs
bias_args
{
static_cast
<
ElementD
*>
(
bias
.
data_ptr
()),
{}};
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
,
bias_args
};
return
evt_compute_args
;
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
,
typename
FP8MathOperator
=
cutlass
::
arch
::
OpMultiplyAdd
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
FP8MathOperator
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Epilogue
=
Epilogue_
<
ElementD
,
OutputTileThreadMap
>
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
using
Epilogue
=
typename
Gemm
::
Epilogue
;
auto
evt_args
=
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...);
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
Gemm
,
typename
FallbackGemm
,
typename
...
EpilogueArgs
>
inline
void
fallback_cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static
const
int
max_shared_mem_per_block_opt_in
=
get_cuda_max_shared_memory_per_block_opt_in
(
0
);
size_t
const
gemm_shared_mem_size
=
sizeof
(
typename
Gemm
::
KernelType
::
SharedStorage
);
size_t
const
fallback_gemm_shared_mem_size
=
sizeof
(
typename
FallbackGemm
::
KernelType
::
SharedStorage
);
if
(
gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
)
{
return
cutlass_gemm_caller
<
Gemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
TORCH_CHECK
(
fallback_gemm_shared_mem_size
<=
max_shared_mem_per_block_opt_in
);
return
cutlass_gemm_caller
<
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_default
{
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M256
{
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M64
{
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm75_config_M32
{
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm75_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm75_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM256
=
typename
sm75_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128
=
Cutlass2xGemmDefault
;
using
Cutlass2xGemmM64
=
typename
sm75_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm75_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using
FallbackGemm
=
Cutlass2xGemmDefault
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// M in [1, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM256
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_default
{
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M64
{
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M32
{
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm80_config_M16
{
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm80_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass2xGemmDefault
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128BigN
=
typename
sm80_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM128SmallN
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM64
=
typename
sm80_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM32
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
using
Cutlass2xGemmM16
=
typename
sm80_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using
FallbackGemm
=
typename
sm80_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM16
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM32
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM64
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
small_n
=
n
<
8192
;
if
(
small_n
)
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128SmallN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmM128BigN
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
// M in (128, inf)
return
fallback_cutlass_gemm_caller
<
Cutlass2xGemmDefault
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_fp8_fallback_gemm
{
// Shared Memory required by this Gemm - 61440 bytes
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
;
};
struct
sm89_fp8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M128
{
// M in (64, 128]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8196
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAdd
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_fp8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
using
FP8MathOperator
=
typename
cutlass
::
arch
::
OpMultiplyAddFastAccum
;
static
const
int32_t
MainLoopStages
=
5
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
FallbackGemm
=
typename
sm89_fp8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
24576
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
MainLoopStages
,
FP8MathOperator
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_fp8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_fp8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_fp8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_fp8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_fp8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_fp8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
0 → 100644
View file @
e661d594
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
>
struct
sm89_int8_fallback_gemm
{
// Shared mem requirement : 61440
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
static
int32_t
const
MainLoopStages
=
5
;
using
Cutlass2xGemm
=
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
;
};
struct
sm89_int8_config_default
{
// M in (256, inf)
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M256
{
// M in (128, 256]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
4096
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M128
{
// M in (64, 128]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
np2
<=
16384
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M64
{
// M in (32, 64]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
3
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M32
{
// M in (16, 32]
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
128
>
;
using
WarpShape
=
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
struct
sm89_int8_config_M16
{
// M in [1, 16]
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
static
void
dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
using
FallbackGemm
=
typename
sm89_int8_fallback_gemm
<
InType
,
OutType
,
Epilogue
>::
Cutlass2xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
uint32_t
const
np2
=
next_pow_2
(
n
);
if
(
np2
<=
8192
)
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
using
TileShape
=
cutlass
::
gemm
::
GemmShape
<
16
,
128
,
128
>
;
return
vllm
::
fallback_cutlass_gemm_caller
<
vllm
::
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
vllm
::
enable_sm89_to_sm90
,
InType
,
OutType
,
Epilogue
,
TileShape
,
WarpShape
,
InstructionShape
,
4
>
,
FallbackGemm
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm89_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
// M in [1, 16]
return
sm89_int8_config_M16
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
32
)
{
// M in (16, 32]
return
sm89_int8_config_M32
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// M in (32, 64]
return
sm89_int8_config_M64
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// M in (64, 128]
return
sm89_int8_config_M128
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// M in (128, 256]
return
sm89_int8_config_M256
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// M in (256, inf)
return
sm89_int8_config_default
::
dispatch
<
InType
,
OutType
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
e661d594
...
@@ -18,8 +18,6 @@
...
@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
...
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
...
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleBDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
0
/*
Stages
*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
};
};
/*
/*
...
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
...
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
BiasDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
ElementD
>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
BiasDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
0
/*
Stages
*/
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
public:
public:
...
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
I
nt
<
0
>
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
i
nt
64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
Int
<
0
>
{}
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
...
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
CUTLASS_CHECK
(
status
);
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
e661d594
...
@@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
...
@@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
if
(
cuda_device_capability
>=
90
)
{
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
89
)
{
}
else
if
(
cuda_device_capability
>=
89
)
{
// CUTLASS Kernels have not been tuned for Ada Lovelace systems
return
CUDA_VERSION
>=
12040
;
// and are slower than torch.mm. Return false unconditionally in this case.
return
false
;
// Once the CUTLASS kernels have been optimized for Lovelace systems,
// use the following check:
// return CUDA_VERSION >= 12040;
}
}
#endif
#endif
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
e661d594
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
return
{};
// Squash missing return statement warning
}
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
return
{};
// Squash missing return statement warning
}
}
// The following macro is used to dispatch the conversion function based on
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/fp8/common.cu
View file @
e661d594
...
@@ -48,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
...
@@ -48,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
__shared__
float
cache
[
1024
];
int
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
// the current thread in cache[threadIdx.x]
...
...
csrc/quantization/fp8/nvidia/quant_utils.cuh
View file @
e661d594
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
...
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE
,
fp8_type
);
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
return
(
uint8_t
)
res
;
#endif
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// float -> fp8
// float -> fp8
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
...
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
}
#endif
#endif
assert
(
false
);
assert
(
false
);
__builtin_unreachable
();
// Suppress missing return statement warning
}
}
// The following macro is used to dispatch the conversion function based on
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
e661d594
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "marlin.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
static_assert(std::is_same<scalar_t, half>::value || \
...
@@ -59,24 +60,27 @@ __global__ void Marlin(
...
@@ -59,24 +60,27 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
bool
is_k_full
,
bool
has_zp
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
return
torch
::
empty
({
1
,
1
});
...
@@ -532,16 +536,18 @@ __global__ void Marlin(
...
@@ -532,16 +536,18 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// same size, which might involve multiple column "slices" (of width 16 *
...
@@ -595,6 +601,8 @@ __global__ void Marlin(
...
@@ -595,6 +601,8 @@ __global__ void Marlin(
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
if
(
slice_col_par
>=
n_tiles
)
{
...
@@ -602,6 +610,7 @@ __global__ void Marlin(
...
@@ -602,6 +610,7 @@ __global__ void Marlin(
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
}
// Compute all information about the current slice which is required for
// Compute all information about the current slice which is required for
...
@@ -632,6 +641,7 @@ __global__ void Marlin(
...
@@ -632,6 +641,7 @@ __global__ void Marlin(
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
locks
+=
n_tiles
;
slice_col
=
0
;
slice_col
=
0
;
par_id
++
;
}
}
};
};
init_slice
();
init_slice
();
...
@@ -1120,44 +1130,53 @@ __global__ void Marlin(
...
@@ -1120,44 +1130,53 @@ __global__ void Marlin(
};
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
if
constexpr
(
!
has_zp
)
{
// This code does not handle group_blocks == 0,
return
;
// which signifies act_order.
}
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
has_zp
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
0
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
sh_zp_stage
+
=
cur_group_id
*
zp_sh_strid
e
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pip
e
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
}
}
};
};
...
@@ -1321,7 +1340,7 @@ __global__ void Marlin(
...
@@ -1321,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
// results in FP16 (but still reduce with FP32 compute).
...
@@ -1382,6 +1401,53 @@ __global__ void Marlin(
...
@@ -1382,6 +1401,53 @@ __global__ void Marlin(
}
}
};
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
int
par_offset
=
c_size
*
n_tiles
*
par_id
;
int
slice_offset
=
c_size
*
slice_col
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
par_offset
+
slice_offset
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
// in fragment layout.
...
@@ -1606,7 +1672,11 @@ __global__ void Marlin(
...
@@ -1606,7 +1672,11 @@ __global__ void Marlin(
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
}
if
(
last
)
// only the last block in a slice actually writes the result
if
(
last
)
// only the last block in a slice actually writes the result
...
@@ -1661,8 +1731,8 @@ __global__ void Marlin(
...
@@ -1661,8 +1731,8 @@ __global__ void Marlin(
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr,
num_groups,
\
A_ptr, B_ptr, C_ptr,
C_tmp_ptr,
s_ptr, zp_ptr, g_idx_ptr, \
prob_m, prob_n, prob_k, locks
);
\
num_groups,
prob_m, prob_n, prob_k, locks
, use_fp32_reduce);
\
}
}
typedef
struct
{
typedef
struct
{
...
@@ -1801,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
...
@@ -1801,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
true
;
return
true
;
}
}
int
determine_reduce_max_m
(
int
prob_m
,
int
max_par
)
{
constexpr
int
tile_m_size
=
16
;
if
(
prob_m
<=
tile_m_size
)
{
return
tile_m_size
;
}
else
if
(
prob_m
<=
tile_m_size
*
2
)
{
return
tile_m_size
*
2
;
}
else
if
(
prob_m
<=
tile_m_size
*
3
)
{
return
tile_m_size
*
3
;
}
else
if
(
prob_m
<=
tile_m_size
*
4
)
{
return
tile_m_size
*
4
;
}
else
{
int
cur_par
=
min
(
div_ceil
(
prob_m
,
tile_m_size
*
4
),
max_par
);
return
tile_m_size
*
4
*
cur_par
;
}
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_act_order
,
bool
is_k_full
,
...
@@ -1880,18 +1971,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1880,18 +1971,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
zp
,
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
int
num_groups
,
int
group_size
,
int
dev
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
max_par
)
{
int
sms
,
int
max_par
,
bool
use_fp32_reduce
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
if
(
has_zp
)
{
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
// TODO: remove alias when we start supporting other 8bit types
int
num_bits
=
q_type
.
size_bits
();
int
tot_m
=
prob_m
;
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
...
@@ -1970,6 +2072,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
...
@@ -1970,6 +2072,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
...
@@ -2042,18 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
...
@@ -2042,18 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
}
}
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
bool
is_k_full
,
bool
has_zp
,
// Verify num_bits
bool
use_fp32_reduce
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
if
(
has_zp
)
{
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
||
*
b_q_type
==
vllm
::
kU8
,
int
pack_factor
=
32
/
num_bits
;
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
->
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
->
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
// Verify A
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
...
@@ -2099,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2099,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// Alloc C tmp buffer that is going to be used for the global reduce
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
reduce_max_m
=
0
;
reduce_n
=
0
;
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
// auto -1)
int
thread_k
=
-
1
;
int
thread_k
=
-
1
;
...
@@ -2169,22 +2293,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2169,22 +2293,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
marlin
::
marlin_mm
_f16i4
<
half
>
(
marlin
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm
_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BF
loat
16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
f
loat
>
(),
b_
zero
s
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
b_
scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
csrc/quantization/marlin/dense/common/base.h
0 → 100644
View file @
e661d594
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template
<
typename
T
,
int
n
>
struct
Vec
{
T
elems
[
n
];
__device__
T
&
operator
[](
int
i
)
{
return
elems
[
i
];
}
};
Prev
1
2
3
4
5
6
7
8
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment