Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1698 additions
and
387 deletions
+1698
-387
csrc/core/math.hpp
csrc/core/math.hpp
+7
-0
csrc/cutlass_extensions/common.cpp
csrc/cutlass_extensions/common.cpp
+11
-0
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+35
-0
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+2
-0
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+4
-2
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+123
-1
csrc/ops.h
csrc/ops.h
+11
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+5
-4
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+3
-369
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
+160
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
+96
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
+140
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+2
-10
csrc/sparse/cutlass/sparse_compressor_c3x.cu
csrc/sparse/cutlass/sparse_compressor_c3x.cu
+165
-0
csrc/sparse/cutlass/sparse_compressor_entry.cu
csrc/sparse/cutlass/sparse_compressor_entry.cu
+42
-0
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+303
-0
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+496
-0
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+70
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+22
-0
docs/requirements-docs.txt
docs/requirements-docs.txt
+1
-1
No files found.
csrc/core/math.hpp
0 → 100644
View file @
96ae75ad
#include <climits>
#include <iostream>
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
\ No newline at end of file
csrc/cutlass_extensions/common.cpp
0 → 100644
View file @
96ae75ad
#include "cutlass_extensions/common.hpp"
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
\ No newline at end of file
csrc/
quantization/cutlass_w8a8
/common.hpp
→
csrc/
cutlass_extensions
/common.hpp
View file @
96ae75ad
...
@@ -2,20 +2,27 @@
...
@@ -2,20 +2,27 @@
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
#include <climits>
#include <climits>
#include "cuda_runtime.h"
#include <iostream>
/**
/**
* Helper function for checking CUTLASS errors
* Helper function for checking CUTLASS errors
*/
*/
#define CUTLASS_CHECK(status) \
#define CUTLASS_CHECK(status) \
{ \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlass::Status error = status; \
cutlassGetStatusString(status)) \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
}
inline
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
/**
if
(
num
<=
1
)
return
num
;
* Panic wrapper for unwinding CUDA runtime errors
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
*/
}
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
}
inline
int
get_cuda_max_shared_memory_per_block_opt_in
(
int
const
device
)
{
inline
int
get_cuda_max_shared_memory_per_block_opt_in
(
int
const
device
)
{
int
max_shared_mem_per_block_opt_in
=
0
;
int
max_shared_mem_per_block_opt_in
=
0
;
...
@@ -25,3 +32,4 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
...
@@ -25,3 +32,4 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
return
max_shared_mem_per_block_opt_in
;
return
max_shared_mem_per_block_opt_in
;
}
}
int32_t
get_sm_version_num
();
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
96ae75ad
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
/*
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
96ae75ad
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
/*
...
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
...
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// This utility function constructs the arguments for the load descriptors
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
96ae75ad
...
@@ -123,6 +123,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -123,6 +123,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
}
}
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_global_mem_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
tokens_cnts
,
int32_t
*
cumsum
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
if
(
threadIdx
.
x
<
num_experts
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
template
<
typename
scalar_t
,
int
TOPK
>
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
...
@@ -147,7 +233,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -147,7 +233,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
// If we have very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and use that. The num_experts >= 256 is just a
// temporary solution to unblock Deepseek V3.
if
(
num_experts
>=
256
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
const
int32_t
mem_tokens_cnts
=
((
num_experts
+
1
)
*
num_experts
)
*
sizeof
(
int32_t
);
const
int32_t
mem_cumsum
=
(
num_experts
+
1
)
*
sizeof
(
int32_t
);
// allocate global memory
int32_t
*
tokens_cnts
;
int32_t
*
cumsum
;
cudaMalloc
(
&
tokens_cnts
,
mem_tokens_cnts
);
cudaMalloc
(
&
cumsum
,
mem_cumsum
);
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_global_mem_kernel
<
scalar_t
>
;
kernel
<<<
1
,
num_thread
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
tokens_cnts
,
cumsum
);
cudaFree
(
tokens_cnts
);
cudaFree
(
cumsum
);
});
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
int32_t
shared_mem_normal
=
((
num_thread
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
int32_t
shared_mem_normal
=
((
num_thread
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
...
@@ -185,6 +305,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -185,6 +305,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
topk_ids
.
numel
());
}
}
});
});
}
}
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
...
...
csrc/ops.h
View file @
96ae75ad
...
@@ -303,6 +303,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
...
@@ -303,6 +303,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_compress_entry
(
torch
::
Tensor
&
a_compressed
,
torch
::
Tensor
&
e
,
torch
::
Tensor
const
&
a
);
#endif
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
96ae75ad
...
@@ -21,15 +21,16 @@
...
@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "common.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
// clang-format on
using
namespace
cute
;
using
namespace
cute
;
/*
/*
Epilogue
functions can be defined to post-process the output before it is
Epilogue
s defined in,
written to GPU memory.
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
Epilogues
must contain a public type named EVTCompute of type Sm80EVT,
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
EVTCompute::Arguments struct.
*/
*/
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/all.h>
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include <ATen/cuda/CUDAContext.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include <iostream>
#include <sstream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
using
namespace
vllm
;
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
NVIDIA GPUs with sm90a (Hopper) or later.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
*/
namespace
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
StrideC
,
4
,
ElementD
,
StrideD
,
4
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
16
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
16
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
0 → 100644
View file @
96ae75ad
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using
namespace
cute
;
namespace
vllm
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
StrideC
,
4
,
ElementD
,
StrideD
,
4
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
16
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
16
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
0 → 100644
View file @
96ae75ad
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
0 → 100644
View file @
96ae75ad
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* Gemm shape.
*/
namespace
vllm
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
// For M > 128 and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_int8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
96ae75ad
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
...
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
return
false
;
}
}
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
...
...
csrc/sparse/cutlass/sparse_compressor_c3x.cu
0 → 100644
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
/// Make A structured sparse by replacing elements with 0 and compress it
template
<
typename
ElementA_
,
typename
ElementAcc_
>
bool
cutlass_sparse_compress
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
||
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
a
.
dtype
()
==
torch
::
kFloat16
||
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
a
.
dim
()
==
2
)
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
0
)
%
4
==
0
)
// Required for semi-structured sparsity
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
)
int
m
=
a
.
size
(
0
);
int
k
=
a
.
size
(
1
);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using
ElementA
=
ElementA_
;
using
LayoutTagA
=
cutlass
::
layout
::
RowMajor
;
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
// B matrix configuration
using
ElementB
=
ElementA
;
using
LayoutTagB
=
cutlass
::
layout
::
ColumnMajor
;
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
// C/D matrix configuration
using
ElementC
=
float
;
using
LayoutTagC
=
cutlass
::
layout
::
ColumnMajor
;
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
// Core kernel configurations
using
ElementAccumulator
=
ElementAcc_
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShapeRef
=
Shape
<
_128
,
_128
,
_64
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementA
,
cutlass
::
float_e4m3_t
>
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
,
cutlass
::
gemm
::
KernelTmaWarpSpecialized
>::
type
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
ProblemShape
=
Shape
<
int
,
int
,
int
,
int
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutTagC
,
AlignmentC
,
ElementC
,
LayoutTagC
,
AlignmentC
,
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassSparseTensorOp
,
ElementA
,
LayoutTagA
,
AlignmentA
,
ElementB
,
LayoutTagB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
StrideA
=
cutlass
::
gemm
::
TagToStrideA_t
<
LayoutTagA
>
;
using
StrideE
=
StrideA
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
// The n (=1) dimension does not matter for the compressor
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
1
,
k
,
1
};
using
LayoutA
=
typename
GemmKernel
::
CollectiveMainloop
::
LayoutA
;
using
LayoutE
=
typename
GemmKernel
::
CollectiveMainloop
::
LayoutE
;
using
ElementE
=
typename
GemmKernel
::
CollectiveMainloop
::
ElementE
;
using
SparseConfig
=
typename
GemmKernel
::
CollectiveMainloop
::
SparseConfig
;
// Offline compressor kernel
using
CompressorUtility
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressorUtility
<
ProblemShape
,
ElementA
,
LayoutTagA
,
SparseConfig
>
;
using
CompressorKernel
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressor
<
ProblemShape
,
ElementA
,
LayoutTagA
,
SparseConfig
,
cutlass
::
arch
::
Sm90
>
;
using
Compressor
=
cutlass
::
transform
::
device
::
TransformUniversalAdapter
<
CompressorKernel
>
;
auto
[
M
,
N
,
K
,
L
]
=
prob_shape
;
StrideA
stride_A
;
stride_A
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
M
,
K
,
L
));
CompressorUtility
compressor_utility
(
prob_shape
,
stride_A
);
int
ME
=
compressor_utility
.
get_metadata_m_physical
();
int
KE
=
compressor_utility
.
get_metadata_k_physical
();
int
KC
=
compressor_utility
.
get_tensorA_k_physical
();
auto
a_ptr
=
static_cast
<
ElementA
*>
(
a
.
data_ptr
());
auto
a_nzs_ptr
=
static_cast
<
ElementA
*>
(
a_nzs
.
data_ptr
());
auto
a_meta_ptr
=
static_cast
<
typename
Gemm
::
CollectiveMainloop
::
ElementE
*>
(
a_meta
.
data_ptr
());
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
0
;
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
typename
Compressor
::
Arguments
arguments
{
prob_shape
,
{
a_ptr
,
stride_A
,
a_nzs_ptr
,
a_meta_ptr
},
{
hw_info
}};
Compressor
compressor_op
;
size_t
workspace_size
=
Compressor
::
get_workspace_size
(
arguments
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
CUTLASS_CHECK
(
compressor_op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
compressor_op
.
initialize
(
arguments
,
workspace
.
get
()));
CUTLASS_CHECK
(
compressor_op
.
run
());
CUDA_CHECK
(
cudaDeviceSynchronize
());
return
true
;
}
bool
cutlass_sparse_compress_sm90
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
if
(
a
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_sparse_compress
<
cutlass
::
bfloat16_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
return
cutlass_sparse_compress
<
cutlass
::
half_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
return
cutlass_sparse_compress
<
cutlass
::
float_e4m3_t
,
float
>
(
a_nzs
,
a_meta
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
return
cutlass_sparse_compress
<
int8_t
,
int32_t
>
(
a_nzs
,
a_meta
,
a
);
}
return
false
;
}
#endif
csrc/sparse/cutlass/sparse_compressor_entry.cu
0 → 100644
View file @
96ae75ad
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool
cutlass_sparse_compress_sm90
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
);
#endif
bool
cutlass_sparse_compress_entry
(
torch
::
Tensor
&
a_nzs
,
torch
::
Tensor
&
a_meta
,
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
a_meta
.
dim
()
==
2
&&
a_nzs
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
size
(
0
)
==
a_nzs
.
size
(
0
)
&&
a
.
size
(
0
)
==
a_meta
.
size
(
0
)
&&
a_nzs
.
size
(
1
)
*
2
==
a
.
size
(
1
)
&&
a_meta
.
size
(
1
)
*
2
*
4
==
a
.
size
(
1
));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
a_nzs
.
stride
(
1
)
==
1
&&
a_meta
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
a
.
stride
(
0
)
%
8
==
0
);
// 8 Byte Alignment for Compression
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
return
cutlass_sparse_compress_sm90
(
a_nzs
,
a_meta
,
a
);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
0 → 100644
View file @
96ae75ad
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM256
=
typename
sm90_fp8_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM512
=
typename
sm90_fp8_config_M512
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm1
=
typename
sm90_fp8_config_1
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm2
=
typename
sm90_fp8_config_2
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm3
=
typename
sm90_fp8_config_3
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm4
=
typename
sm90_fp8_config_4
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm5
=
typename
sm90_fp8_config_5
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm6
=
typename
sm90_fp8_config_6
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm7
=
typename
sm90_fp8_config_7
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm8
=
typename
sm90_fp8_config_8
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
bt_nzs
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
// Batch size
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm2
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
4096
||
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm1
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
128
)
{
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm3
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm5
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm4
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
256
)
{
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm6
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm8
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm7
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
{
if
(
n
==
6144
||
n
==
28672
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm8
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
n
==
4096
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemm7
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
// Otherwise the default heuristic
if
(
mp2
<=
64
)
{
// n in [1, 64]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// n in (64, 128]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// n in (128, 256]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM256
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// n in (256, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM512
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_fp16_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
half_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat16
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_bf16_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
bfloat16_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kBFloat16
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_gemm_sm90_int8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kInt8
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
n
=
out
.
size
(
1
);
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM32NSmall
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM32NBig
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_sparse_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_sparse_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat16
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp16_dispatch
<
cutlass
::
half_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp16_dispatch
<
cutlass
::
half_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
// a.dtype() == torch::kBFloat16
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kBFloat16
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_bf16_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_bf16_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
);
}
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
0 → 100644
View file @
96ae75ad
This diff is collapsed.
Click to expand it.
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
0 → 100644
View file @
96ae75ad
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
)
{
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return
CUDA_VERSION
>=
12020
&&
cuda_device_capability
>=
90
;
#endif
return
false
;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
bt_nzs
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
1
)
==
bt_nzs
.
size
(
0
)
&&
bt_nzs
.
size
(
1
)
*
2
==
a
.
size
(
1
)
&&
a
.
size
(
0
)
==
c
.
size
(
0
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
bt_nzs
.
size
(
0
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
bt_nzs
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
bt_nzs
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
bt_nzs
.
size
(
0
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_sparse_mm_sm90
(
c
,
a
,
bt_nzs
,
bt_meta
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
csrc/torch_bindings.cpp
View file @
96ae75ad
...
@@ -511,6 +511,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -511,6 +511,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"
);
ops
.
def
(
"cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops
.
def
(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_sparse_scaled_mm_supported"
,
&
cutlass_sparse_scaled_mm_supported
);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops
.
def
(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_sparse_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_sparse_mm
);
// CUTLASS sparse matrix compressor
ops
.
def
(
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
" Tensor a) -> bool"
);
ops
.
impl
(
"cutlass_sparse_compress_entry"
,
&
cutlass_sparse_compress_entry
);
// Mamba selective scan kernel
// Mamba selective scan kernel
ops
.
def
(
ops
.
def
(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"selective_scan_fwd(Tensor! u, Tensor! delta,"
...
...
docs/requirements-docs.txt
View file @
96ae75ad
sphinx==6.2.1
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
sphinx-copybutton==0.5.2
myst-parser==
2
.0.
0
myst-parser==
3
.0.
1
sphinx-argparse==0.4.0
sphinx-argparse==0.4.0
msgspec
msgspec
cloudpickle
cloudpickle
...
...
Prev
1
2
3
4
5
6
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment