Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1698 additions
and
387 deletions
+1698
-387
csrc/core/math.hpp
csrc/core/math.hpp
+7
-0
csrc/cutlass_extensions/common.cpp
csrc/cutlass_extensions/common.cpp
+11
-0
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+35
-0
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+2
-0
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+4
-2
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+123
-1
csrc/ops.h
csrc/ops.h
+11
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+5
-4
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+3
-369
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
+160
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
+96
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
+140
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+2
-10
csrc/sparse/cutlass/sparse_compressor_c3x.cu
csrc/sparse/cutlass/sparse_compressor_c3x.cu
+165
-0
csrc/sparse/cutlass/sparse_compressor_entry.cu
csrc/sparse/cutlass/sparse_compressor_entry.cu
+42
-0
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+303
-0
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+496
-0
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+70
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+22
-0
docs/requirements-docs.txt
docs/requirements-docs.txt
+1
-1
No files found.
csrc/core/math.hpp
0 → 100644
View file @
96ae75ad
#include <climits>
#include <iostream>
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
\ No newline at end of file
csrc/cutlass_extensions/common.cpp
0 → 100644
View file @
96ae75ad
#include "cutlass_extensions/common.hpp"
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
\ No newline at end of file
csrc/
quantization/cutlass_w8a8
/common.hpp
→
csrc/
cutlass_extensions
/common.hpp
View file @
96ae75ad
...
...
@@ -2,20 +2,27 @@
#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
/**
* Panic wrapper for unwinding CUDA runtime errors
*/
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
}
inline
int
get_cuda_max_shared_memory_per_block_opt_in
(
int
const
device
)
{
int
max_shared_mem_per_block_opt_in
=
0
;
...
...
@@ -25,3 +32,4 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
return
max_shared_mem_per_block_opt_in
;
}
int32_t
get_sm_version_num
();
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
96ae75ad
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
96ae75ad
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
...
...
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
96ae75ad
...
...
@@ -123,6 +123,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_global_mem_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
tokens_cnts
,
int32_t
*
cumsum
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
if
(
threadIdx
.
x
<
num_experts
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
...
...
@@ -147,7 +233,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
// If we have very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and use that. The num_experts >= 256 is just a
// temporary solution to unblock Deepseek V3.
if
(
num_experts
>=
256
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
const
int32_t
mem_tokens_cnts
=
((
num_experts
+
1
)
*
num_experts
)
*
sizeof
(
int32_t
);
const
int32_t
mem_cumsum
=
(
num_experts
+
1
)
*
sizeof
(
int32_t
);
// allocate global memory
int32_t
*
tokens_cnts
;
int32_t
*
cumsum
;
cudaMalloc
(
&
tokens_cnts
,
mem_tokens_cnts
);
cudaMalloc
(
&
cumsum
,
mem_cumsum
);
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_global_mem_kernel
<
scalar_t
>
;
kernel
<<<
1
,
num_thread
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
tokens_cnts
,
cumsum
);
cudaFree
(
tokens_cnts
);
cudaFree
(
cumsum
);
});
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
int32_t
shared_mem_normal
=
((
num_thread
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
...
...
@@ -185,6 +305,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
}
});
}
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
...
...
csrc/ops.h
View file @
96ae75ad
...
...
@@ -303,6 +303,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_compress_entry
(
torch
::
Tensor
&
a_compressed
,
torch
::
Tensor
&
e
,
torch
::
Tensor
const
&
a
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
96ae75ad
...
...
@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "common.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
using
namespace
cute
;
/*
Epilogue
functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues
must contain a public type named EVTCompute of type Sm80EVT,
Epilogue
s defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/all.h>
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <iostream>
#include <sstream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using
namespace
vllm
;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
StrideC
,
4
,
ElementD
,
StrideD
,
4
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
16
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
16
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
0 → 100644
View file @
96ae75ad
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using
namespace
cute
;
namespace
vllm
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
StrideC
,
4
,
ElementD
,
StrideD
,
4
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
16
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
16
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
0 → 100644
View file @
96ae75ad
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
0 → 100644
View file @
96ae75ad
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
96ae75ad
...
...
@@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
}
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
...
...
csrc/sparse/cutlass/sparse_compressor_c3x.cu
0 → 100644
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
/// Make A structured sparse by replacing elements with 0 and compress it
template
<
typename
ElementA_
,
typename
ElementAcc_
>
bool
cutlass_sparse_compress
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
||
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
a
.
dtype
()
==
torch
::
kFloat16
||
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
a
.
dim
()
==
2
)
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
0
)
%
4
==
0
)
// Required for semi-structured sparsity
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
)
int
m
=
a
.
size
(
0
);
int
k
=
a
.
size
(
1
);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using
ElementA
=
ElementA_
;
using
LayoutTagA
=
cutlass
::
layout
::
RowMajor
;
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
// B matrix configuration
using
ElementB
=
ElementA
;
using
LayoutTagB
=
cutlass
::
layout
::
ColumnMajor
;
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
// C/D matrix configuration
using
ElementC
=
float
;
using
LayoutTagC
=
cutlass
::
layout
::
ColumnMajor
;
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
// Core kernel configurations
using
ElementAccumulator
=
ElementAcc_
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShapeRef
=
Shape
<
_128
,
_128
,
_64
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementA
,
cutlass
::
float_e4m3_t
>
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
,
cutlass
::
gemm
::
KernelTmaWarpSpecialized
>::
type
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
ProblemShape
=
Shape
<
int
,
int
,
int
,
int
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutTagC
,
AlignmentC
,
ElementC
,
LayoutTagC
,
AlignmentC
,
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassSparseTensorOp
,
ElementA
,
LayoutTagA
,
AlignmentA
,
ElementB
,
LayoutTagB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
StrideA
=
cutlass
::
gemm
::
TagToStrideA_t
<
LayoutTagA
>
;
using
StrideE
=
StrideA
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
// The n (=1) dimension does not matter for the compressor
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
1
,
k
,
1
};
using
LayoutA
=
typename
GemmKernel
::
CollectiveMainloop
::
LayoutA
;
using
LayoutE
=
typename
GemmKernel
::
CollectiveMainloop
::
LayoutE
;
using
ElementE
=
typename
GemmKernel
::
CollectiveMainloop
::
ElementE
;
using
SparseConfig
=
typename
GemmKernel
::
CollectiveMainloop
::
SparseConfig
;
// Offline compressor kernel
using
CompressorUtility
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressorUtility
<
ProblemShape
,
ElementA
,
LayoutTagA
,
SparseConfig
>
;
using
CompressorKernel
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressor
<
ProblemShape
,
ElementA
,
LayoutTagA
,
SparseConfig
,
cutlass
::
arch
::
Sm90
>
;
using
Compressor
=
cutlass
::
transform
::
device
::
TransformUniversalAdapter
<
CompressorKernel
>
;
auto
[
M
,
N
,
K
,
L
]
=
prob_shape
;
StrideA
stride_A
;
stride_A
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
M
,
K
,
L
));
CompressorUtility
compressor_utility
(
prob_shape
,
stride_A
);
int
ME
=
compressor_utility
.
get_metadata_m_physical
();
int
KE
=
compressor_utility
.
get_metadata_k_physical
();
int
KC
=
compressor_utility
.
get_tensorA_k_physical
();
auto
a_ptr
=
static_cast
<
ElementA
*>
(
a
.
data_ptr
());
auto
a_nzs_ptr
=
static_cast
<
ElementA
*>
(
a_nzs
.
data_ptr
());
auto
a_meta_ptr
=
static_cast
<
typename
Gemm
::
CollectiveMainloop
::
ElementE
*>
(
a_meta
.
data_ptr
());
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
0
;
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
typename
Compressor
::
Arguments
arguments
{
prob_shape
,
{
a_ptr
,
stride_A
,
a_nzs_ptr
,
a_meta_ptr
},
{
hw_info
}};
Compressor
compressor_op
;
size_t
workspace_size
=
Compressor
::
get_workspace_size
(
arguments
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
CUTLASS_CHECK
(
compressor_op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
compressor_op
.
initialize
(
arguments
,
workspace
.
get
()));
CUTLASS_CHECK
(
compressor_op
.
run
());
CUDA_CHECK
(
cudaDeviceSynchronize
());
return
true
;
}
bool
cutlass_sparse_compress_sm90
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
if
(
a
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_sparse_compress
<
cutlass
::
bfloat16_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
return
cutlass_sparse_compress
<
cutlass
::
half_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
return
cutlass_sparse_compress
<
cutlass
::
float_e4m3_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
return
cutlass_sparse_compress
<
int8_t
,
int32_t
>
(
a_nzs
,
a_meta
,
a
);
}
return
false
;
}
#endif
csrc/sparse/cutlass/sparse_compressor_entry.cu
0 → 100644
View file @
96ae75ad
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool
cutlass_sparse_compress_sm90
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
);
#endif
bool
cutlass_sparse_compress_entry
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
a_meta
.
dim
()
==
2
&&
a_nzs
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
size
(
0
)
==
a_nzs
.
size
(
0
)
&&
a
.
size
(
0
)
==
a_meta
.
size
(
0
)
&&
a_nzs
.
size
(
1
)
*
2
==
a
.
size
(
1
)
&&
a_meta
.
size
(
1
)
*
2
*
4
==
a
.
size
(
1
));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
a_nzs
.
stride
(
1
)
==
1
&&
a_meta
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
a
.
stride
(
0
)
%
8
==
0
);
// 8 Byte Alignment for Compression
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
return
cutlass_sparse_compress_sm90
(
a_nzs
,
a_meta
,
a
);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
0 → 100644
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM256
=
typename
sm90_fp8_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM512
=
typename
sm90_fp8_config_M512
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm1
=
typename
sm90_fp8_config_1
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm2
=
typename
sm90_fp8_config_2
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm3
=
typename
sm90_fp8_config_3
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm4
=
typename
sm90_fp8_config_4
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm5
=
typename
sm90_fp8_config_5
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm6
=
typename
sm90_fp8_config_6
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm7
=
typename
sm90_fp8_config_7
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm8
=
typename
sm90_fp8_config_8
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
bt_nzs
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
// Batch size
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm2
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
4096
||
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm1
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
128
)
{
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm3
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm5
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm4
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
256
)
{
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm6
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm8
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm7
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
if
(
n
==
6144
||
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm8
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm7
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
// Otherwise the default heuristic
if
(
mp2
<=
64
)
{
// n in [1, 64]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// n in (64, 128]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// n in (128, 256]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM256
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// n in (256, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM512
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp16_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
half_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat16
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_bf16_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
bfloat16_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kBFloat16
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_sparse_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat16
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp16_dispatch
<
cutlass
::
half_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp16_dispatch
<
cutlass
::
half_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
// a.dtype() == torch::kBFloat16
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kBFloat16
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_bf16_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_bf16_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
);
}
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
0 → 100644
View file @
96ae75ad
This diff is collapsed.
Click to expand it.
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
0 → 100644
View file @
96ae75ad
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
)
{
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return
CUDA_VERSION
>=
12020
&&
cuda_device_capability
>=
90
;
#endif
return
false
;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
bt_nzs
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
1
)
==
bt_nzs
.
size
(
0
)
&&
bt_nzs
.
size
(
1
)
*
2
==
a
.
size
(
1
)
&&
a
.
size
(
0
)
==
c
.
size
(
0
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
bt_nzs
.
size
(
0
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
bt_nzs
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
bt_nzs
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
bt_nzs
.
size
(
0
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_sparse_mm_sm90
(
c
,
a
,
bt_nzs
,
bt_meta
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
csrc/torch_bindings.cpp
View file @
96ae75ad
...
...
@@ -511,6 +511,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops
.
def
(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_sparse_scaled_mm_supported"
,
&
cutlass_sparse_scaled_mm_supported
);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops
.
def
(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_sparse_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_sparse_mm
);
// CUTLASS sparse matrix compressor
ops
.
def
(
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
" Tensor a) -> bool"
);
ops
.
impl
(
"cutlass_sparse_compress_entry"
,
&
cutlass_sparse_compress_entry
);
// Mamba selective scan kernel
ops
.
def
(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
...
...
docs/requirements-docs.txt
View file @
96ae75ad
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==
2
.0.
0
myst-parser==
3
.0.
1
sphinx-argparse==0.4.0
msgspec
cloudpickle
...
...
Prev
1
2
3
4
5
6
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment