Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4d3a2c28
Commit
4d3a2c28
authored
Dec 30, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.5' into v0.6.5-dev
parents
92ec5d8e
2d1b9baa
Changes
430
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
814 additions
and
1221 deletions
+814
-1221
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+7
-5
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+5
-5
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+69
-35
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-15
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+82
-15
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+7
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+19
-5
csrc/ops.h
csrc/ops.h
+83
-106
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+39
-0
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+5
-160
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+52
-31
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+28
-14
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+27
-26
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+0
-302
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+7
-305
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+51
-23
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+7
-174
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+160
-0
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+6
-0
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+160
-0
No files found.
Too many changes to show.
To preserve performance only
430 of 430+
files are displayed.
Plain diff
Email patch
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
View file @
4d3a2c28
...
...
@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
...
...
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
View file @
4d3a2c28
...
...
@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx
_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_
k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
4
*
zp
_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_
m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
csrc/moe/marlin_moe_ops.cu
View file @
4d3a2c28
...
...
@@ -25,9 +25,12 @@
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
...
...
@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
64
,
256
,
256
},
// Reduce K 2X, increase N 2X
{
64
,
128
,
128
},
// Reduce K 2X, same N
{
64
,
64
,
128
},
// Reduce both 2X
};
thread_config_t
large_batch_thread_configs
[]
=
{
...
...
@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{
128
,
128
,
256
},
// Reduce N 2X, increase K 2X
{
64
,
128
,
128
},
// Reduce N 2X, same K
{
128
,
64
,
128
},
// Reduce N 4X, increase K 2X
{
64
,
64
,
128
},
// Reduce N 4X, same K
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
...
...
@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int
load_groups
=
tb_groups
*
STAGES
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
return
load_groups
*
tb_n
*
4
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
...
@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)
\
else if (KERNEL_FUNCTION(
q_type, thread_n_blocks, thread_k_blocks,
\
has_act_order, group
_blocks,
num_
thread
s,
blocks, \
max_shared_mem, stream, A_ptr, B_ptr, C_ptr,
\
sorted_ids_ptr, topk_weights_ptr, s_ptr,
g_idx_ptr,
\
expert_offsets_ptr, num_groups, expert_idx,
\
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \
locks,
replicate_input, apply_weights, m_block, \
max_par,
exec_cfg.max_m_blocks)) { \
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(
\
q_type, thread_n
_blocks, thread
_k_
blocks,
has_act_order,
\
group_blocks, num_threads, blocks, max_shared_mem, stream,
\
A_ptr, B_ptr, C_ptr,
sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr,
expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m,
locks,
\
replicate_input, apply_weights, m_block,
max_par,
\
exec_cfg.max_m_blocks)) {
\
}
void
marlin_mm_moe
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
const
void
*
sorted_ids
,
const
void
*
topk_weights
,
const
void
*
topk_ids
,
const
void
*
s
,
const
void
*
g_idx
,
const
void
*
perm
,
void
*
a_tmp
,
void
*
expert_offsets
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
int
group_size
,
int
num_experts
,
int
topk
,
int
moe_block_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
const
void
*
topk_ids
,
const
void
*
s
,
void
*
zp
,
const
void
*
g_idx
,
const
void
*
perm
,
void
*
a_tmp
,
void
*
expert_offsets
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
num_experts
,
int
topk
,
int
moe_block_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
...
...
@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4
*
C_ptr
=
(
int4
*
)
C
;
const
float
*
topk_weights_ptr
=
(
const
float
*
)
topk_weights
;
const
int
*
sorted_ids_ptr
=
(
const
int
*
)
sorted_ids
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
+
(((
group_size
==
-
1
||
group_size
==
0
)
?
1
:
prob_k
/
group_size
)
*
prob_n
/
8
)
*
expert_idx
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
+
num_groups
*
prob_n
/
8
*
expert_idx
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
+
num_groups
*
prob_n
/
(
pack_factor
*
4
)
*
expert_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
+
prob_k
*
expert_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
+
prob_k
*
expert_idx
;
int
*
locks
=
(
int
*
)
workspace
;
...
...
@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
}
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4b8
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku8b128
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
...
...
@@ -477,15 +482,24 @@ torch::Tensor marlin_gemm_moe(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
->
str
());
torch
::
Tensor
&
b_zeros
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
bool
has_zp
=
b_zeros
.
size
(
1
)
!=
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
.
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
int
max_par
=
4
;
...
...
@@ -521,6 +535,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
1
);
TORCH_CHECK
(
VLLM_IMPLIES
(
!
is_k_full
,
has_act_order
),
"if is_k_full is false, has_act_order must be true"
);
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
...
...
@@ -542,13 +559,30 @@ torch::Tensor marlin_gemm_moe(
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
3
,
"b_zeros rank = "
,
rank
,
" is not 3"
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
num_groups
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
2
)
==
size_n
/
pack_factor
,
"b_zeros dim 2 = "
,
b_zeros
.
size
(
2
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
marlin_moe
::
marlin_mm_moe
(
a
.
data_ptr
(),
b_q_weights
.
data_ptr
(),
c
.
data_ptr
(),
sorted_ids
.
data_ptr
(),
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
num_experts
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
num_experts
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm_moe"
,
&
marlin_gemm_moe
);
}
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
92ec5d8e
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
csrc/moe_align_
block_size
_kernels.cu
→
csrc/moe
/moe
_align_
sum
_kernels.cu
View file @
4d3a2c28
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "
../
cuda_compat.h"
#include "
../
dispatch_utils.h"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX_SHARED_MEM_SIZE 64 * 1024
namespace
vllm
{
namespace
moe
{
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
...
...
@@ -37,14 +39,14 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t
*
tokens_cnts
=
nullptr
;
int32_t
*
cumsum
=
nullptr
;
if
(
experts_num_exceed_limit
)
{
// 2d tensor with shape (
num_experts
+ 1, num_experts)
// 2d tensor with shape (
blockDim.x
+ 1, num_experts)
tokens_cnts
=
global_tokens_cnts_ptr
;
// 1d tensor with shape (num_experts + 1)
cumsum
=
shared_mem
;
}
else
{
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (
num_experts
+ 1, num_experts)
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (
blockDim.x
+ 1, num_experts)
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
}
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
...
...
@@ -63,10 +65,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
if
(
threadIdx
.
x
<
num_experts
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
__syncthreads
();
...
...
@@ -89,9 +93,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
/**
...
...
@@ -116,6 +122,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., topk, d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
scalar_t
x
=
0.0
;
#pragma unroll
for
(
int
k
=
0
;
k
<
TOPK
;
++
k
)
{
x
+=
VLLM_LDG
(
&
input
[
token_idx
*
TOPK
*
d
+
k
*
d
+
idx
]);
}
out
[
token_idx
*
d
+
idx
]
=
x
;
}
}
}
// namespace moe
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
...
...
@@ -125,7 +149,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
int32_t
shared_mem_normal
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
int32_t
shared_mem_normal
=
((
num_thread
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
const
bool
experts_num_exceed_limit
=
shared_mem_normal
>
MAX_SHARED_MEM_SIZE
;
...
...
@@ -146,8 +171,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
}
else
{
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
,
false
>
;
...
...
@@ -159,6 +184,48 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
nullptr
,
num_experts
,
block_size
,
topk_ids
.
numel
());
}
});
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
torch
::
Tensor
&
output
)
// [num_tokens, hidden_size]
{
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
num_tokens
=
output
.
numel
()
/
hidden_size
;
const
int
topk
=
input
.
size
(
1
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
switch
(
topk
)
{
case
2
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
2
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
3
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
3
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
4
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
4
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
default:
at
::
sum_out
(
output
,
input
,
1
);
break
;
}
}
csrc/moe/moe_ops.h
View file @
4d3a2c28
...
...
@@ -5,3 +5,10 @@
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
csrc/moe/torch_bindings.cpp
View file @
4d3a2c28
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
...
...
@@ -9,16 +8,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor! input, Tensor output) -> ()"
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m
.
def
(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
#ifndef USE_ROCM
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
// conditionally compiled so impl registration is in source file
#endif
}
...
...
csrc/ops.h
View file @
4d3a2c28
...
...
@@ -5,6 +5,30 @@
#include "core/scalar_type.hpp"
#include <vector>
torch
::
Tensor
weak_ref_tensor
(
torch
::
Tensor
&
tensor
)
{
// Ensure tensor is on CUDA
if
(
!
tensor
.
is_cuda
())
{
throw
std
::
runtime_error
(
"Tensor must be on CUDA device"
);
}
// Get the raw data pointer
void
*
data_ptr
=
tensor
.
data_ptr
();
// Get tensor sizes and strides
std
::
vector
<
int64_t
>
sizes
=
tensor
.
sizes
().
vec
();
std
::
vector
<
int64_t
>
strides
=
tensor
.
strides
().
vec
();
// Get tensor options (dtype, device)
auto
options
=
tensor
.
options
();
// Create a new tensor from the raw data pointer
auto
new_tensor
=
torch
::
from_blob
(
data_ptr
,
sizes
,
strides
,
options
);
return
new_tensor
;
}
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
...
@@ -158,6 +182,24 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
weight
,
torch
::
Tensor
&
scales
,
double
const
epsilon
,
std
::
optional
<
torch
::
Tensor
>
scale_ub
,
std
::
optional
<
torch
::
Tensor
>
residual
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
...
...
@@ -187,6 +229,9 @@ void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
double
threshold
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
@@ -231,62 +276,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
namespace
machete
{
std
::
vector
<
std
::
string
>
supported_schedules
(
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
);
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
};
// namespace machete
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
#endif
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
...
...
@@ -297,11 +288,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#ifndef USE_ROCM
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -316,14 +303,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -351,48 +330,46 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// c10::optional<torch::Tensor> const& scale_ub);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
std
::
vector
<
torch
::
Tensor
>
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
seq_idx_
,
const
c10
::
optional
<
at
::
Tensor
>&
initial_states_
,
const
c10
::
optional
<
at
::
Tensor
>&
final_states_out_
,
bool
silu_activation
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
dispose
(
fptr_t
_fa
);
int64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
tuple
<
torch
::
Tensor
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
csrc/opt/activation_kernels_opt.cu
View file @
4d3a2c28
...
...
@@ -107,8 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return
(
T
)(
0.5
f
*
f
*
(
1.0
f
+
::
tanhf
(
inner
)));
}
template
<
typename
T
>
__device__
__forceinline__
T
fatrelu_kernel
(
const
T
&
x
,
const
float
threshold
)
{
const
float
f
=
(
float
)
x
;
return
(
T
)(
f
>
threshold
?
f
:
0.0
f
);
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
float
)>
__global__
void
act_and_mul_kernel_with_param
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
const
float
param
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
}
}
// namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
...
...
@@ -163,4 +196,10 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d],
torch
::
Tensor
&
input
,
// [..., 2 * d]
double
threshold
)
{
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM
(
vllm
::
fatrelu_kernel
,
threshold
);
}
\ No newline at end of file
csrc/opt/layernorm_kernels_opt.cu
View file @
4d3a2c28
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
...
...
@@ -55,154 +48,6 @@ __global__ void rms_norm_kernel(
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template
<
>
struct
_typeConvert
<
c10
::
Half
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__half
;
using
packed_hip_type
=
__half2
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__nv_bfloat16
;
using
packed_hip_type
=
__nv_bfloat162
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16Vec
{
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert
(
width
>
0
&&
(
width
&
(
width
-
1
))
==
0
,
"Width is not a positive power of 2!"
);
using
Converter
=
_typeConvert
<
scalar_t
>
;
using
T1
=
typename
Converter
::
hip_type
;
using
T2
=
typename
Converter
::
packed_hip_type
;
T1
data
[
width
];
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
temp_f
.
x
*=
scale
;
temp_f
.
y
*=
scale
;
T2
temp
=
Converter
::
convert
(
temp_f
);
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
data
[
i
]
=
Converter
::
convert
(
temp
);
}
}
return
*
this
;
}
__device__
float
sum_squares
()
const
{
float
result
=
0.0
f
;
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
x
=
Converter
::
convert
(
data
[
i
]);
result
+=
x
*
x
;
}
}
return
result
;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
...
...
csrc/prepare_inputs/advance_step.cu
View file @
4d3a2c28
...
...
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
...
...
@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
}
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
&
t
,
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
const
&
t
,
int64_t
const
size_0
,
int64_t
const
size_1
,
c10
::
ScalarType
const
type
)
{
bool
size_0_cond
=
true
;
...
...
@@ -77,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}
/// each thread processes a block per query
__global__
void
advance_step_flashinfer_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
...
...
@@ -123,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
// Update paged_kv_indptr
if
(
idx
==
0
)
{
paged_kv_indptr_ptr
[
idx
]
=
0
;
}
if
(
idx
<
num_queries
)
{
int
sum
=
0
;
for
(
int
i
=
0
;
i
<=
idx
;
++
i
)
{
...
...
@@ -135,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}
__global__
void
advance_step_flashinfer_indices_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_indices_ptr
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
max_num_blocks_per_seq
,
int
*
paged_kv_indices_ptr
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
int
row
=
idx
/
block_tables_stride
;
int
col
=
idx
%
block_tables_stride
;
if
(
row
<
num_queries
&&
col
<
block_table_bound_ptr
[
row
])
{
paged_kv_indices_ptr
[
paged_kv_indptr_ptr
[
row
]
+
col
]
=
block_tables_ptr
[
row
*
block_tables_stride
+
col
];
// note: max_num_blocks_per_seq = block_tables.stride(0)
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if
(
num_queries
<
tid
&&
tid
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
tid
]
=
paged_kv_indptr_ptr
[
num_queries
];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if
(
num_queries
<
row
&&
row
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
row
]
=
paged_kv_indptr_ptr
[
num_queries
];
// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for
(
int
idx
=
tid
;
idx
<
(
num_seqs
*
max_num_blocks_per_seq
);
idx
+=
(
gridDim
.
x
*
blockDim
.
x
))
{
// block_tables-row = paged_kv_indptr[queryNum]
int
queryNum
=
idx
/
max_num_blocks_per_seq
;
int
col
=
idx
%
max_num_blocks_per_seq
;
if
(
queryNum
<
num_queries
&&
col
<
block_table_bound_ptr
[
queryNum
])
{
int
indices_arr_idx
=
paged_kv_indptr_ptr
[
queryNum
]
+
col
;
int
block_tables_idx
=
queryNum
*
max_num_blocks_per_seq
+
col
;
paged_kv_indices_ptr
[
indices_arr_idx
]
=
block_tables_ptr
[
block_tables_idx
];
}
}
}
...
...
@@ -211,7 +238,7 @@ void advance_step_flashinfer(
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %
d
\n
"
,
block_tables
.
stride
(
0
));
printf
(
" block_tables.stride(0) = %
zu
\n
"
,
block_tables
.
stride
(
0
));
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
...
...
@@ -236,22 +263,16 @@ void advance_step_flashinfer(
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
if
(
logging
)
{
printf
(
"launching kernel with %d blocks
\n
"
,
blocks
);
}
// TODO(will): support arbitrary block_tables stride
if
((
blocks
*
threads
)
/
block_tables
.
stride
(
0
)
<
num_queries
)
{
TORCH_CHECK
(
false
,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,"
,
" increasing the block size or take smaller steps."
,
" num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
int
block_tables_stride
=
block_tables
.
stride
(
0
);
TORCH_CHECK
((
blocks
*
threads
>
num_queries
),
"multi-step: not enough threads to map to num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
if
(
logging
)
{
printf
(
"launching kernels with %d blocks and %d threads
\n
"
,
blocks
,
threads
);
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
...
...
@@ -270,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
...
...
@@ -303,4 +324,4 @@ void advance_step_flashinfer(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
}
\ No newline at end of file
}
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
4d3a2c28
...
...
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
/
scale
);
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
}
}
...
...
@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
...
...
@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
float
val
=
static_cast
<
float
>
(
input
[
i
]);
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
...
...
@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
*
tmp_scale
);
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
}
}
...
...
@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int
const
token_idx
=
blockIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
val
=
static_cast
<
float
>
(
input
[
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
}
...
...
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
4d3a2c28
...
...
@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using
namespace
vllm
;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
...
...
@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
...
...
@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
...
...
@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
...
...
@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
return
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
...
...
@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
4d3a2c28
...
...
@@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
...
...
@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrZeroLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrZeroBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
// it would technically work but no use case as data_ptr is never nullptr
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static_assert
(
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
return
Arguments
{
data_ptr
};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
protected:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBiasAzp
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowOrZeroLoad
<
ElementD
>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using
AzpWithAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute float(accum - azp_adj), both operands are int32_t
using
ComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAzp
,
Accum
,
AzpWithAdj
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAzp
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBiasAzpToken
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowOrZeroLoad
<
ElementD
>;
// Per-token azp term, shape (m,1)
using
Azp
=
typename
SUPER
::
template
ColLoad
<
int32_t
>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using
AzpAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute azp * azp_adj
using
ComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
int32_t
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAzp
,
Azp
,
AzpAdj
>
;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using
ComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAcc
,
Accum
,
EVTComputeAzp
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_args
=
SUPER
::
template
args_from_tensor
<
Azp
,
int32_t
>(
azp
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
4d3a2c28
...
...
@@ -23,11 +23,12 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "
broadcast_load
_epilogue_c3x.hpp"
#include "
cutlass_extensions/epilogue/scaled_mm
_epilogue
s
_c3x.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
...
...
@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
#endif
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
&&
!
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
||
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBiasAzp
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
,
true
>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using
AzpWithAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute float(accum - azp_adj), both operands are int32_t
using
ComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeAzp
,
Accum
,
AzpWithAdj
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAzp
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBiasAzpToken
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
,
true
>;
// Per-token azp term, shape (m,1)
using
Azp
=
typename
SUPER
::
template
ColLoad
<
int32_t
>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using
AzpAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute azp * azp_adj
using
ComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
int32_t
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeAzp
,
Azp
,
AzpAdj
>
;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using
ComputeAcc
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAcc
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeAcc
,
Accum
,
EVTComputeAzp
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_args
=
SUPER
::
template
args_from_tensor
<
Azp
,
int32_t
>(
azp
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
...
@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
b_scales
);
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
4d3a2c28
...
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined
CUDA_VERSION && CUDA_VERSION >= 12000
#if defined
ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -114,26 +114,41 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
return
;
}
if
(
version_num
>=
75
)
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -174,25 +189,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype "
,
c
.
dtype
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#
if
defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_azp_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
#else
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_azp_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
\ No newline at end of file
csrc/quantization/fp8/common.cu
View file @
4d3a2c28
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "common.cuh"
#include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#ifndef USE_ROCM
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
float
old
;
old
=
(
value
>=
0
)
?
__int_as_float
(
atomicMax
((
int
*
)
addr
,
__float_as_int
(
value
)))
:
__uint_as_float
(
atomicMin
((
unsigned
int
*
)
addr
,
__float_as_uint
(
value
)));
return
old
;
}
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
}
else
{
x
=
val
/
scale
;
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#else
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
hip_fp8
(
r
).
data
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template
<
typename
scalar_t
>
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
int64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t
tmp
=
0.0
;
while
(
i
<
num_elems
)
{
float
x
=
static_cast
<
float
>
(
input
[
i
]);
tmp
=
max
(
tmp
,
fabs
(
x
));
i
+=
blockDim
.
x
*
gridDim
.
x
;
}
cache
[
threadIdx
.
x
]
=
tmp
;
__syncthreads
();
// Now perform parallel reduction within the thread block
int
ib
=
blockDim
.
x
/
2
;
while
(
ib
!=
0
)
{
if
(
threadIdx
.
x
<
ib
&&
cache
[
threadIdx
.
x
+
ib
]
>
cache
[
threadIdx
.
x
])
{
cache
[
threadIdx
.
x
]
=
cache
[
threadIdx
.
x
+
ib
];
}
__syncthreads
();
ib
/=
2
;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
}
}
template
<
typename
scalar_t
>
struct
__align__
(
8
)
vec4_t
{
scalar_t
x
;
scalar_t
y
;
scalar_t
z
;
scalar_t
w
;
};
typedef
struct
__align__
(
4
)
{
FP8_TYPE
x
;
FP8_TYPE
y
;
FP8_TYPE
z
;
FP8_TYPE
w
;
}
float8x4_t
;
template
<
typename
scalar_t
>
__device__
float
thread_max_vec
(
scalar_t
const
*
__restrict__
input
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
float
absmax_val
=
0.0
f
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
z
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
w
));
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
}
return
absmax_val
;
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
...
...
@@ -204,8 +35,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
...
...
csrc/quantization/fp8/common.cuh
0 → 100644
View file @
4d3a2c28
#pragma once
#include "quantization/vectorization.cuh"
#include <cmath>
#include <c10/core/ScalarType.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
constexpr
static
auto
kFp8Type
=
c10
::
CppTypeToScalarType
<
FP8_TYPE
>::
value
;
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
float
old
;
old
=
(
value
>=
0
)
?
__int_as_float
(
atomicMax
((
int
*
)
addr
,
__float_as_int
(
value
)))
:
__uint_as_float
(
atomicMin
((
unsigned
int
*
)
addr
,
__float_as_uint
(
value
)));
return
old
;
}
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
}
else
{
x
=
val
/
scale
;
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#else
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
hip_fp8
(
r
).
data
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template
<
typename
scalar_t
>
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
int64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t
tmp
=
0.0
;
while
(
i
<
num_elems
)
{
float
x
=
static_cast
<
float
>
(
input
[
i
]);
tmp
=
max
(
tmp
,
fabs
(
x
));
i
+=
blockDim
.
x
*
gridDim
.
x
;
}
cache
[
threadIdx
.
x
]
=
tmp
;
__syncthreads
();
// Now perform parallel reduction within the thread block
int
ib
=
blockDim
.
x
/
2
;
while
(
ib
!=
0
)
{
if
(
threadIdx
.
x
<
ib
&&
cache
[
threadIdx
.
x
+
ib
]
>
cache
[
threadIdx
.
x
])
{
cache
[
threadIdx
.
x
]
=
cache
[
threadIdx
.
x
+
ib
];
}
__syncthreads
();
ib
/=
2
;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
}
}
template
<
typename
scalar_t
>
__device__
float
thread_max_vec
(
scalar_t
const
*
__restrict__
input
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
float
absmax_val
=
0.0
f
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
z
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
w
));
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
}
return
absmax_val
;
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
using
float8x4_t
=
q8x4_t
<
FP8_TYPE
>
;
// Vectorized input/output to better utilize memory bandwidth.
auto
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
#pragma unroll 4
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
}
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/fp8/fp8_marlin.cu
View file @
4d3a2c28
...
...
@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
0 → 100644
View file @
4d3a2c28
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scalar_out_t
,
bool
has_residual
=
false
>
__device__
void
rms_norm_dynamic_per_token_quant_vec
(
scalar_out_t
*
__restrict__
out
,
// [..., hidden_size]
float
*
__restrict__
scales
,
// [num_tokens]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
float
rms
=
0.0
f
;
float
token_scale
=
0.0
f
;
// Compute rms
vllm
::
vectorized
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
var_epsilon
,
residual
);
// Compute scale
vllm
::
vectorized
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
min_scaling_factor
,
hidden_size
,
residual
);
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
vllm
::
vectorized
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
true
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
1.0
f
/
token_scale
,
hidden_size
,
residual
);
}
else
{
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm
::
vectorized
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
false
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
token_scale
,
hidden_size
,
residual
);
}
}
// RMS norm + quant kernel
template
<
typename
scalar_t
,
typename
scalar_out_t
,
bool
has_residual
=
false
>
__global__
void
rms_norm_dynamic_per_token_quant_kernel
(
scalar_out_t
*
__restrict__
out
,
// [..., hidden_size]
float
*
__restrict__
scales
,
// [num_tokens]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
if
(
can_vectorize
)
{
return
rms_norm_dynamic_per_token_quant_vec
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
out
,
scales
,
input
,
weight
,
scale_ub
,
var_epsilon
,
min_scaling_factor
,
hidden_size
,
residual
);
}
float
rms
=
0.0
f
;
float
token_scale
=
0.0
f
;
// Compute RMS
vllm
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
var_epsilon
,
residual
);
// Compute Scale
vllm
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
min_scaling_factor
,
hidden_size
,
residual
);
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
vllm
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
true
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
1.0
f
/
token_scale
,
hidden_size
,
residual
);
}
else
{
// FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
false
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
token_scale
,
hidden_size
,
residual
);
}
}
}
// namespace vllm
// Residual add + RMS norm + dynamic per token
template
<
typename
scalar_in_t
>
void
rms_norm_dynamic_per_token_quant_dispatch
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scales
,
// [num_tokens]
double
const
var_epsilon
,
// Variance epsilon used in norm calculation
std
::
optional
<
at
::
Tensor
>
const
&
scale_ub
,
std
::
optional
<
at
::
Tensor
>&
residual
)
{
int32_t
hidden_size
=
input
.
size
(
-
1
);
int32_t
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
float
min_scaling_factor
=
out
.
dtype
()
==
torch
::
kInt8
?
std
::
numeric_limits
<
float
>::
epsilon
()
:
1.0
f
/
(
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
()
*
512.
f
);
if
(
residual
.
has_value
())
{
VLLM_DISPATCH_QUANT_TYPES
(
out
.
scalar_type
(),
"rms_norm_dynamic_per_token_quant_kernel"
,
[
&
]
{
vllm
::
rms_norm_dynamic_per_token_quant_kernel
<
scalar_in_t
,
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
min_scaling_factor
,
hidden_size
,
residual
->
data_ptr
<
scalar_in_t
>
());
});
}
else
{
VLLM_DISPATCH_QUANT_TYPES
(
out
.
scalar_type
(),
"rms_norm_dynamic_per_token_quant_kernel"
,
[
&
]
{
vllm
::
rms_norm_dynamic_per_token_quant_kernel
<
scalar_in_t
,
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
min_scaling_factor
,
hidden_size
,
nullptr
);
});
}
}
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scales
,
// [num_tokens]
double
const
var_epsilon
,
// Variance epsilon used in norm calculation
std
::
optional
<
at
::
Tensor
>
scale_ub
,
std
::
optional
<
at
::
Tensor
>
residual
)
{
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
if
(
scale_ub
.
has_value
())
{
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
);
}
TORCH_CHECK
(
scales
.
dtype
()
==
torch
::
kFloat32
);
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_dynamic_per_token_quant_dispatch"
,
[
&
]
{
rms_norm_dynamic_per_token_quant_dispatch
<
scalar_t
>
(
out
,
input
,
weight
,
scales
,
var_epsilon
,
scale_ub
,
residual
);
});
}
Prev
1
…
5
6
7
8
9
10
11
12
13
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment