Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f48954a4
Commit
f48954a4
authored
Jun 12, 2024
by
zhuwenwen
Browse files
merge v0.5.0
parents
1dba29d3
8f89d720
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
810 additions
and
314 deletions
+810
-314
csrc/moe/moe_ops.cpp
csrc/moe/moe_ops.cpp
+0
-8
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+1
-1
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+18
-11
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+12
-0
csrc/moe_align_block_size_kernels.cu
csrc/moe_align_block_size_kernels.cu
+3
-3
csrc/ops.h
csrc/ops.h
+40
-35
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+6
-6
csrc/punica/punica_ops.cu
csrc/punica/punica_ops.cu
+3
-3
csrc/punica/punica_ops.h
csrc/punica/punica_ops.h
+3
-3
csrc/punica/punica_pybind.cpp
csrc/punica/punica_pybind.cpp
+0
-13
csrc/punica/torch_bindings.cpp
csrc/punica/torch_bindings.cpp
+18
-0
csrc/pybind.cpp
csrc/pybind.cpp
+0
-111
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+1
-1
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+4
-4
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+72
-16
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+28
-22
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+389
-0
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+75
-46
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+126
-29
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
+11
-2
No files found.
csrc/moe/moe_ops.cpp
deleted
100644 → 0
View file @
1dba29d3
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
}
csrc/moe/moe_ops.h
View file @
f48954a4
#pragma once
#include <torch/
extension
.h>
#include <torch/
all
.h>
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
f48954a4
...
...
@@ -16,18 +16,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace
vllm
{
namespace
moe
{
static
constexpr
int
WARP_SIZE
=
32
;
/// Aligned array type
template
<
typename
T
,
...
...
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
thread_max
=
max
(
thread_max
,
__shfl_xor_sync
(
0xFFFFFFFF
,
thread_max
,
mask
,
THREADS_PER_ROW
));
thread_max
=
max
(
thread_max
,
VLLM_SHFL_XOR_SYNC_WIDTH
(
thread_max
,
mask
,
THREADS_PER_ROW
));
}
// From this point, thread max in all the threads have the max within the row.
...
...
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
row_sum
+=
__shfl_xor_sync
(
0xFFFFFFFF
,
row_sum
,
mask
,
THREADS_PER_ROW
);
row_sum
+=
VLLM_SHFL_XOR_SYNC_WIDTH
(
row_sum
,
mask
,
THREADS_PER_ROW
);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
...
...
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
float
other_max
=
__shfl_xor_sync
(
0xFFFFFFFF
,
max_val
,
mask
,
THREADS_PER_ROW
);
int
other_expert
=
__shfl_xor_sync
(
0xFFFFFFFF
,
expert
,
mask
,
THREADS_PER_ROW
);
float
other_max
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
max_val
,
mask
,
THREADS_PER_ROW
);
int
other_expert
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
expert
,
mask
,
THREADS_PER_ROW
);
// We want lower indices to "win" in every thread so we break ties this way
if
(
other_max
>
max_val
||
(
other_max
==
max_val
&&
other_expert
<
expert
))
...
...
@@ -383,7 +390,7 @@ struct TopkConstants
{
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
,
""
);
static
constexpr
int
VECs_PER_THREAD
=
std
::
max
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
static
constexpr
int
THREADS_PER_ROW
=
EXPERTS
/
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
WARP_SIZE
/
THREADS_PER_ROW
;
...
...
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
int
BYTES_PER_LDG
=
std
::
min
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
...
...
csrc/moe/torch_bindings.cpp
0 → 100644
View file @
f48954a4
#include "registration.h"
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
csrc/moe_align_block_size_kernels.cu
View file @
f48954a4
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ATen.h>
...
...
@@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
64_t
num_experts
,
int
64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
csrc/ops.h
View file @
f48954a4
#pragma once
#include <torch/
extension
.h>
#include <torch/
library
.h>
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
);
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
kv_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
);
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
kv_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
float
epsilon
);
double
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
key
,
int
64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
key
,
int
64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
int
64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
@@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
);
int
64_t
split_k_iters
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
);
torch
::
Tensor
_zeros
,
int
64_t
split_k_iters
,
int
64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
...
...
@@ -88,14 +90,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
float
scale
);
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
...
...
@@ -103,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
bool
use_exllama
,
int
64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
64_t
bit
);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
...
...
@@ -113,28 +118,28 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
64_t
num_experts
,
int
64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
using
fptr_t
=
u
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
64_t
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
64_t
max_size
,
int
64_t
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
dispose
(
fptr_t
_fa
);
int
meta_size
();
int
64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
std
::
tuple
<
torch
::
Tensor
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
...
...
csrc/pos_encoding_kernels.cu
View file @
f48954a4
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -127,7 +127,7 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
int
64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
int64_t
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
...
...
@@ -138,7 +138,7 @@ void rotary_embedding(
int64_t
key_stride
=
key
.
stride
(
-
2
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
...
...
@@ -168,9 +168,9 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
int
64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int
rot_dim
,
bool
is_neox
,
int
64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
)
{
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
...
...
@@ -180,7 +180,7 @@ void batched_rotary_embedding(
int64_t
key_stride
=
key
.
stride
(
-
2
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
...
...
csrc/punica/punica_ops.cu
View file @
f48954a4
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
...
...
@@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
}
void
dispatch_bgmv
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
float
scale
)
{
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
)
{
CHECK_INPUT
(
y
);
CHECK_INPUT
(
x
);
CHECK_INPUT
(
w
);
...
...
@@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
void
dispatch_bgmv_low_level
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
float
scale
,
int64_t
h_in
,
int64_t
h_out
,
double
scale
,
int64_t
h_in
,
int64_t
h_out
,
int64_t
y_offset
)
{
CHECK_INPUT
(
y
);
CHECK_INPUT
(
x
);
...
...
csrc/punica/punica_ops.h
View file @
f48954a4
#pragma once
#include <torch/
extension
.h>
#include <torch/
all
.h>
void
dispatch_bgmv
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
float
scale
);
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
double
scale
);
void
dispatch_bgmv_low_level
(
torch
::
Tensor
y
,
torch
::
Tensor
x
,
torch
::
Tensor
w
,
torch
::
Tensor
indicies
,
int64_t
layer_idx
,
float
scale
,
int64_t
h_in
,
int64_t
h_out
,
double
scale
,
int64_t
h_in
,
int64_t
h_out
,
int64_t
y_offset
);
csrc/punica/punica_pybind.cpp
deleted
100644 → 0
View file @
1dba29d3
#include <torch/extension.h>
#include "punica_ops.h"
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"dispatch_bgmv"
,
&
dispatch_bgmv
,
"dispatch_bgmv"
);
m
.
def
(
"dispatch_bgmv_low_level"
,
&
dispatch_bgmv_low_level
,
"dispatch_bgmv_low_level"
);
}
csrc/punica/torch_bindings.cpp
0 → 100644
View file @
f48954a4
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()"
);
m
.
impl
(
"dispatch_bgmv"
,
torch
::
kCUDA
,
&
dispatch_bgmv
);
m
.
def
(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()"
);
m
.
impl
(
"dispatch_bgmv_low_level"
,
torch
::
kCUDA
,
&
dispatch_bgmv_low_level
);
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
csrc/pybind.cpp
deleted
100644 → 0
View file @
1dba29d3
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
// vLLM custom ops
pybind11
::
module
ops
=
m
.
def_submodule
(
"ops"
,
"vLLM custom operators"
);
// Attention ops
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
// Activation ops
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
// Layernorm
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
"In-place fused Add and RMS Normalization"
);
// Rotary embedding
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)"
);
// Quantization ops
#ifndef USE_ROCM
ops
.
def
(
"aqlm_gemm"
,
&
aqlm_gemm
,
"Quantized GEMM for AQLM"
);
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization."
);
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
// ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
// "Compute FP8 quantized tensor for given scaling factor");
// ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
// "Compute FP8 quantized tensor and scaling factor");
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size."
);
ops
.
def
(
"static_scaled_int8_quant"
,
&
static_scaled_int8_quant
,
"Compute int8 quantized tensor for given scaling factor"
);
// Cache ops
pybind11
::
module
cache_ops
=
m
.
def_submodule
(
"cache_ops"
,
"vLLM cache ops"
);
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
"Swap in (out) the cache blocks from src to dst"
);
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
"Convert the key and value cache to fp8 data type"
);
// Cuda utils
pybind11
::
module
cuda_utils
=
m
.
def_submodule
(
"cuda_utils"
,
"vLLM cuda utils"
);
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
,
"Gets the specified device attribute."
);
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
,
"Gets the maximum shared memory per block device attribute."
);
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11
::
module
custom_ar
=
m
.
def_submodule
(
"custom_ar"
,
"custom allreduce"
);
custom_ar
.
def
(
"init_custom_ar"
,
&
init_custom_ar
,
"init_custom_ar"
);
custom_ar
.
def
(
"should_custom_ar"
,
&
should_custom_ar
,
"should_custom_ar"
);
custom_ar
.
def
(
"all_reduce_reg"
,
&
all_reduce_reg
,
"all_reduce_reg"
);
custom_ar
.
def
(
"all_reduce_unreg"
,
&
all_reduce_unreg
,
"all_reduce_unreg"
);
custom_ar
.
def
(
"dispose"
,
&
dispose
,
"dispose"
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
,
"meta_size"
);
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
,
"register_buffer"
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
,
"get_graph_buffer_ipc_meta"
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
,
"register_graph_buffers"
);
#endif
}
csrc/quantization/aqlm/gemm_kernels.cu
View file @
f48954a4
...
...
@@ -18,7 +18,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
f48954a4
...
...
@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
...
...
@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
)
{
torch
::
Tensor
_zeros
,
int
64_t
split_k_iters
,
int
64_t
thx
,
int64_t
thy
)
{
int
in_c
=
_kernel
.
size
(
0
);
int
qout_c
=
_kernel
.
size
(
1
);
int
out_c
=
qout_c
*
8
;
...
...
@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
)
{
int
64_t
split_k_iters
)
{
int
num_in_feats
=
_in_feats
.
size
(
0
);
int
num_in_channels
=
_in_feats
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_in_feats
));
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
f48954a4
#include <ATen/cuda/CUDAContext.h>
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
...
...
@@ -27,33 +28,88 @@ namespace vllm {
template
<
typename
scalar_t
,
typename
scale_type
>
__global__
void
static_scaled_int8_quant_kernel
(
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
scale
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(((
float
)
input
[
token_idx
*
hidden_size
+
i
])
/
scale
);
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
/
scale
);
}
}
template
<
typename
scalar_t
,
typename
scale_type
>
__global__
void
dynamic_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
float
const
block_absmax_val_maybe
=
blockReduceMax
(
absmax_val
);
__shared__
float
block_absmax_val
;
if
(
tid
==
0
)
{
block_absmax_val
=
block_absmax_val_maybe
;
scale
[
token_idx
]
=
block_absmax_val
/
127.0
f
;
}
__syncthreads
();
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
*
tmp_scale
);
}
}
}
// namespace vllm
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
float
scale
)
{
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
scale
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
TORCH_CHECK
(
scale
.
numel
()
==
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
,
hidden_size
);
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
});
}
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scales
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
});
}
csrc/quantization/cutlass_w8a8/
cutlass_visitor_2x_
broadcast_epilogue.hpp
→
csrc/quantization/cutlass_w8a8/broadcast_
load_
epilogue
_c2x
.hpp
View file @
f48954a4
...
...
@@ -33,20 +33,27 @@
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass It's beem modified to support either
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
// Important because this saves us a factor 4x on the number of kernels
// compiled.
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"
// clang-format on
namespace
cutlass
::
epilogue
::
threadblock
{
using
namespace
cute
;
...
...
@@ -59,9 +66,11 @@ template<
>
struct
VisitorRowOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_row
=
nullptr
;
Element
null_default
=
Element
(
0
)
;
bool
row_broadcast
=
true
;
StrideMNL
dRow
=
{};
};
...
...
@@ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast {
auto
coord_v
=
filter
(
tC_cRow
);
auto
dst_v
=
filter
(
tC_rRow
);
if
(
params_ptr
->
ptr_row
)
{
if
(
params_ptr
->
row_broadcast
)
{
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
bool
guard
=
get
<
1
>
(
coord_v
(
i
))
<
n
;
cutlass
::
arch
::
global_load
<
VecType
,
sizeof
(
VecType
)
>
(
dst_v
(
i
),
(
void
const
*
)
&
src_v
(
i
),
guard
);
cutlass
::
arch
::
global_load
<
VecType
,
sizeof
(
VecType
)
>
(
dst_v
(
i
),
(
void
const
*
)
&
src_v
(
i
),
guard
);
}
}
else
{
// In this case we are loading from a scalar and broadcasting
VecType
filled_vec
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
VecLength
;
i
++
)
{
reinterpret_cast
<
Element
*>
(
&
filled_vec
)[
i
]
=
params_ptr
->
null_default
;
reinterpret_cast
<
Element
*>
(
&
filled_vec
)[
i
]
=
*
(
params_ptr
->
ptr_row
)
;
}
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
if
(
get
<
1
>
(
coord_v
(
i
))
<
n
)
{
if
(
get
<
1
>
(
coord_v
(
i
))
<
n
)
{
dst_v
(
i
)
=
filled_vec
;
}
}
...
...
@@ -208,9 +217,11 @@ template<
>
struct
VisitorColOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
Element
null_default
=
Element
(
0
)
;
bool
col_broadcast
=
true
;
StrideMNL
dCol
=
{};
};
...
...
@@ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast {
struct
SharedStorage
{
};
// Global load type
static
int
constexpr
vec_bits
=
ThreadMap
::
kElementsPerAccess
*
sizeof_bits
<
Element
>::
value
;
using
VecType
=
uint_bit_t
<
cute
::
min
(
128
,
vec_bits
)
>
;
static
int
constexpr
VecLength
=
sizeof
(
VecType
)
/
sizeof
(
Element
);
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast
()
{
}
...
...
@@ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast {
int
m
;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
begin_epilogue
()
{
clear
(
tC_rCol
);
...
...
@@ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast {
pred
(
i
)
=
get
<
0
>
(
tC_cCol
(
i
))
<
m
;
}
if
(
params_ptr
->
ptr_col
)
{
if
(
params_ptr
->
col_broadcast
)
{
// In this case we are loading from a column vector and broadcasting
copy_if
(
pred
,
tC_gCol
,
tC_rCol
);
}
else
{
...
...
@@ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
dst_v
);
++
i
)
{
if
(
pred
(
i
)){
dst_v
(
i
)
=
params_ptr
->
null_default
;
if
(
pred
(
i
))
{
dst_v
(
i
)
=
*
(
params_ptr
->
ptr_col
)
;
}
}
}
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
0 → 100644
View file @
f48954a4
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace
cutlass
::
epilogue
::
fusion
{
using
namespace
cute
;
using
namespace
detail
;
// Row vector broadcast
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_0
,
_1
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct
Arguments
{
Element
const
*
ptr_row
=
nullptr
;
bool
row_broadcast
=
true
;
StrideMNL
dRow
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
),
smem_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem_row
.
data
()))
{
}
Params
params
;
Element
*
smem_row
;
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
true
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
template
<
int
EpiTiles
,
class
GTensor
,
class
STensor
>
struct
ProducerLoadCallbacks
:
EmptyProducerLoadCallbacks
{
CUTLASS_DEVICE
ProducerLoadCallbacks
(
GTensor
&&
gRow
,
STensor
&&
sRow
,
Params
const
&
params
)
:
gRow
(
cute
::
forward
<
GTensor
>
(
gRow
)),
sRow
(
cute
::
forward
<
STensor
>
(
sRow
)),
params
(
params
)
{}
GTensor
gRow
;
// (CTA_M,CTA_N)
STensor
sRow
;
// (CTA_M,CTA_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
(
uint64_t
*
full_mbarrier_ptr
,
int
load_iteration
,
bool
issue_tma_load
)
{
if
(
params
.
ptr_row
==
nullptr
)
{
return
;
}
if
(
issue_tma_load
)
{
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr
uint32_t
copy_bytes
=
size
<
1
>
(
CtaTileShapeMNK
{})
*
sizeof_bits_v
<
Element
>
/
8
;
cutlass
::
arch
::
ClusterTransactionBarrier
::
expect_transaction
(
full_mbarrier_ptr
,
copy_bytes
);
// Issue the TMA bulk copy
auto
bulk_copy
=
Copy_Atom
<
SM90_BULK_COPY_AUTO
,
Element
>
{}.
with
(
*
full_mbarrier_ptr
);
// Filter so we don't issue redundant copies over stride-0 modes
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
(
bulk_copy
,
filter
(
gRow
),
filter
(
sRow
(
_
,
_
,
bcast_pipe_index
)));
}
}
};
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
}
template
<
int
EpiTiles
,
class
RTensor
,
class
STensor
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
params
(
params
)
{}
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
tCrRow
,
*
(
params
.
ptr_row
));
return
;
}
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
));
}
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_row
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
tCrRow
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_row
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template
<
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_1
,
_0
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90ColOrScalarBroadcast
{
static_assert
(
Stages
==
0
,
"Column broadcast doesn't support smem usage yet"
);
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
_0
>>
)
||
// col vector broadcast, e.g. per-row alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
int
>>
));
// batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct
SharedStorage
{
};
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
bool
col_broadcast
=
true
;
StrideMNL
dCol
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
col_broadcast
&&
*
(
params
.
ptr_col
)
==
Element
(
0
));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
{
}
Params
params
;
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
()
{
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_col
;
Tensor
tCrCol_mn
=
tCrCol
(
_
,
_
,
_
,
epi_m
,
epi_n
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_col
[
i
]
=
tCrCol_mn
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_col
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
Tensor
mCol
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_col
),
make_shape
(
M
,
N
,
L
),
params
.
dCol
);
Tensor
tCgCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
}
};
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
View file @
f48954a4
#include <stddef.h>
#include <torch/
extension
.h>
#include <torch/
all
.h>
#include <ATen/cuda/CUDAContext.h>
...
...
@@ -22,7 +22,7 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "
cutlass_visitor_2x_
broadcast_epilogue.hpp"
#include "broadcast_
load_
epilogue
_c2x
.hpp"
#include "common.hpp"
// clang-format on
...
...
@@ -48,9 +48,44 @@ using namespace cute;
namespace
{
template
<
typename
Arch
,
typename
ElementAB_
,
typename
ElementD_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm75_to_sm80
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm80_to_sm89
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm89_to_sm90
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
static
void
invoke
(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel
::
invoke
(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
...
...
@@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ArchGuard
<
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
...
...
@@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
;
>::
GemmKernel
>
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
...
...
@@ -145,17 +180,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto
a_scales_ptr
=
a_scales
.
data_ptr
<
float
>
();
auto
b_scales_ptr
=
b_scales
.
data_ptr
<
float
>
();
// If A and B are quantized per-tensor, then these scale tensors are scalars,
// and they are passed in via the second argument.
using
ScaleAArgs
=
typename
Gemm
::
ScaleA
::
Arguments
;
ScaleAArgs
a_args
=
a_scales
.
numel
()
==
1
?
ScaleAArgs
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleAArgs
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
using
ScaleBArgs
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleBArgs
b_args
=
b_scales
.
numel
()
==
1
?
ScaleBArgs
{
nullptr
,
b_scales
.
item
<
float
>
()
,
{}}
:
Scale
B
Args
{
b
_scales
.
data_ptr
<
float
>
(),
{}
,
{}};
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}}
;
Scale
A
Args
a_args
{
a
_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
Gemm
::
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
...
...
@@ -214,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm75
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm75
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
enable_sm75_to_sm80
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -241,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm80
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
enable_sm80_to_sm89
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -269,16 +298,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm89
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass
_2x_gemm
<
cutlass
::
arch
::
Sm89
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
...
...
@@ -286,15 +315,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
cutlass
::
float_e4m3_t
,
cutlass
::
b
float
16
_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float
_e4m3
_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
cutlass
::
arch
::
Sm89
,
enable_sm89_to_sm90
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
View file @
f48954a4
#include <torch/extension.h>
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
...
...
@@ -6,19 +12,20 @@
#include <sstream>
#include <vector>
// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on
...
...
@@ -44,6 +51,26 @@ using namespace cute;
namespace
{
uint32_t
next_pow_2
(
uint32_t
const
num
)
{
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
...
...
@@ -61,7 +88,7 @@ struct cutlass_3x_gemm {
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90Col
OrScalar
Broadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
...
...
@@ -69,7 +96,7 @@ struct cutlass_3x_gemm {
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90Row
OrScalar
Broadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
...
...
@@ -114,9 +141,9 @@ struct cutlass_3x_gemm {
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>
;
cutlass
::
gemm
::
PersistentScheduler
>
>
;
struct
GemmKernel
:
public
KernelType
{};
};
...
...
@@ -162,13 +189,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
using
ScaleA_Args
=
typename
Gemm
::
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleA_Args
a_args
=
a_scales
.
numel
()
==
1
?
ScaleA_Args
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleA_Args
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
ScaleB_Args
b_args
=
b_scales
.
numel
()
==
1
?
ScaleB_Args
{
nullptr
,
b_scales
.
item
<
float
>
(),
{}}
:
ScaleB_Args
{
b_scales
.
data_ptr
<
float
>
(),
{},
{}};
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
args
.
epilogue
.
thread
=
{
a_args
,
{
b_args
}};
...
...
@@ -178,14 +201,96 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
TORCH_CHECK
(
workspace_size
==
0
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
CUTLASS_CHECK
(
status
);
}
template
<
typename
InType
,
typename
OutType
,
int32_t
M
>
struct
sm90_fp8_config
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
>
struct
sm90_fp8_config
<
InType
,
OutType
,
128
>
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
>
struct
sm90_fp8_config
<
InType
,
OutType
,
64
>
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm
<
InType
,
OutType
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
template
<
typename
InType
,
typename
OutType
>
void
cutlass_scaled_mm_dq_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config
<
InType
,
OutType
,
0
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config
<
InType
,
OutType
,
64
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config
<
InType
,
OutType
,
128
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_scaled_mm_dq_dispatcher
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_scaled_mm_dq_dispatcher
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
// m in (128, inf)
return
cutlass_scaled_mm_dq_dispatcher
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -219,25 +324,17 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelCpAsyncWarpSpecializedCooperative
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
return
cutlass_scaled_mm_dq_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
return
cutlass_scaled_mm_dq_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
View file @
f48954a4
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/all.h>
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
...
...
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#endif
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
);
#else
cutlass_scaled_mm_dq_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
#endif
}
else
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_dq_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
);
...
...
Prev
1
2
3
4
5
6
7
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment