Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4d3a2c28
Commit
4d3a2c28
authored
Dec 30, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.5' into v0.6.5-dev
parents
92ec5d8e
2d1b9baa
Changes
430
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
814 additions
and
1221 deletions
+814
-1221
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+7
-5
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+5
-5
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+69
-35
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-15
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+82
-15
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+7
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+19
-5
csrc/ops.h
csrc/ops.h
+83
-106
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+39
-0
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+5
-160
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+52
-31
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+28
-14
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+27
-26
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+0
-302
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+7
-305
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+51
-23
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+7
-174
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+160
-0
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+6
-0
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+160
-0
No files found.
Too many changes to show.
To preserve performance only
430 of 430+
files are displayed.
Plain diff
Email patch
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
View file @
4d3a2c28
...
@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
...
@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
cfg_max_m_blocks
)
{
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
if
(
false
)
{
}
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
...
...
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
View file @
4d3a2c28
...
@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
...
@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx
_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
4
*
zp
_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_
k
,
int
tot_m
,
int
*
locks
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_
m
,
int
prob_n
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
cfg_max_m_blocks
);
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
}
csrc/moe/marlin_moe_ops.cu
View file @
4d3a2c28
...
@@ -25,9 +25,12 @@
...
@@ -25,9 +25,12 @@
#include <iostream>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template
<
typename
T
>
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
inline
std
::
string
str
(
T
x
)
{
...
@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
...
@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
64
,
256
,
256
},
// Reduce K 2X, increase N 2X
{
64
,
256
,
256
},
// Reduce K 2X, increase N 2X
{
64
,
128
,
128
},
// Reduce K 2X, same N
{
64
,
128
,
128
},
// Reduce K 2X, same N
{
64
,
64
,
128
},
// Reduce both 2X
};
};
thread_config_t
large_batch_thread_configs
[]
=
{
thread_config_t
large_batch_thread_configs
[]
=
{
...
@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
...
@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{
128
,
128
,
256
},
// Reduce N 2X, increase K 2X
{
128
,
128
,
256
},
// Reduce N 2X, increase K 2X
{
64
,
128
,
128
},
// Reduce N 2X, same K
{
64
,
128
,
128
},
// Reduce N 2X, same K
{
128
,
64
,
128
},
// Reduce N 4X, increase K 2X
{
128
,
64
,
128
},
// Reduce N 4X, increase K 2X
{
64
,
64
,
128
},
// Reduce N 4X, same K
};
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
...
@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
...
@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int
load_groups
=
int
load_groups
=
tb_groups
*
STAGES
*
2
;
// Chunk size is 2x pipeline over dim K
tb_groups
*
STAGES
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
return
load_groups
*
tb_n
*
4
;
}
else
{
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)
\
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(
q_type, thread_n_blocks, thread_k_blocks,
\
else if (KERNEL_FUNCTION(
\
has_act_order, group
_blocks,
num_
thread
s,
blocks, \
q_type, thread_n
_blocks, thread
_k_
blocks,
has_act_order,
\
max_shared_mem, stream, A_ptr, B_ptr, C_ptr,
\
group_blocks, num_threads, blocks, max_shared_mem, stream,
\
sorted_ids_ptr, topk_weights_ptr, s_ptr,
g_idx_ptr,
\
A_ptr, B_ptr, C_ptr,
sorted_ids_ptr, topk_weights_ptr, s_ptr, \
expert_offsets_ptr, num_groups, expert_idx,
\
zp_ptr, g_idx_ptr,
expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m,
locks,
\
locks,
replicate_input, apply_weights, m_block, \
replicate_input, apply_weights, m_block,
max_par,
\
max_par,
exec_cfg.max_m_blocks)) { \
exec_cfg.max_m_blocks)) {
\
}
}
void
marlin_mm_moe
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
marlin_mm_moe
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
const
void
*
sorted_ids
,
const
void
*
topk_weights
,
const
void
*
sorted_ids
,
const
void
*
topk_weights
,
const
void
*
topk_ids
,
const
void
*
s
,
const
void
*
g_idx
,
const
void
*
topk_ids
,
const
void
*
s
,
void
*
zp
,
const
void
*
perm
,
void
*
a_tmp
,
void
*
expert_offsets
,
const
void
*
g_idx
,
const
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
void
*
expert_offsets
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
is_k_full
,
int
num_groups
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_experts
,
int
topk
,
int
moe_block_size
,
int
dev
,
int
num_groups
,
int
group_size
,
int
num_experts
,
int
topk
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
moe_block_size
,
int
dev
,
cudaStream_t
stream
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
...
@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
...
@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
float
*
topk_weights_ptr
=
(
const
float
*
)
topk_weights
;
const
float
*
topk_weights_ptr
=
(
const
float
*
)
topk_weights
;
const
int
*
sorted_ids_ptr
=
(
const
int
*
)
sorted_ids
;
const
int
*
sorted_ids_ptr
=
(
const
int
*
)
sorted_ids
;
const
int4
*
s_ptr
=
const
int4
*
s_ptr
=
(
const
int4
*
)
s
+
num_groups
*
prob_n
/
8
*
expert_idx
;
(
const
int4
*
)
s
+
const
int4
*
zp_ptr
=
(((
group_size
==
-
1
||
group_size
==
0
)
?
1
:
prob_k
/
group_size
)
*
(
const
int4
*
)
zp
+
num_groups
*
prob_n
/
(
pack_factor
*
4
)
*
expert_idx
;
prob_n
/
8
)
*
expert_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
+
prob_k
*
expert_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
+
prob_k
*
expert_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
+
prob_k
*
expert_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
+
prob_k
*
expert_idx
;
int
*
locks
=
(
int
*
)
workspace
;
int
*
locks
=
(
int
*
)
workspace
;
...
@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
...
@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
}
}
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4b8
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4b8
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku8b128
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku8b128
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4
)
else
{
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
...
@@ -477,15 +482,24 @@ torch::Tensor marlin_gemm_moe(
...
@@ -477,15 +482,24 @@ torch::Tensor marlin_gemm_moe(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
b_zeros
,
const
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
vllm
::
ScalarTypeId
const
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
bool
replicate_input
,
bool
apply_weights
)
{
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
->
str
());
bool
has_zp
=
b_zeros
.
size
(
1
)
!=
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
.
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
int
max_par
=
4
;
int
max_par
=
4
;
...
@@ -521,6 +535,9 @@ torch::Tensor marlin_gemm_moe(
...
@@ -521,6 +535,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = "
,
size_n
);
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
1
);
num_groups
=
b_scales
.
size
(
1
);
TORCH_CHECK
(
VLLM_IMPLIES
(
!
is_k_full
,
has_act_order
),
"if is_k_full is false, has_act_order must be true"
);
if
(
has_act_order
)
{
if
(
has_act_order
)
{
if
(
is_k_full
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
...
@@ -542,13 +559,30 @@ torch::Tensor marlin_gemm_moe(
...
@@ -542,13 +559,30 @@ torch::Tensor marlin_gemm_moe(
}
}
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
3
,
"b_zeros rank = "
,
rank
,
" is not 3"
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
num_groups
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
2
)
==
size_n
/
pack_factor
,
"b_zeros dim 2 = "
,
b_zeros
.
size
(
2
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
marlin_moe
::
marlin_mm_moe
(
marlin_moe
::
marlin_mm_moe
(
a
.
data_ptr
(),
b_q_weights
.
data_ptr
(),
c
.
data_ptr
(),
sorted_ids
.
data_ptr
(),
a
.
data_ptr
(),
b_q_weights
.
data_ptr
(),
c
.
data_ptr
(),
sorted_ids
.
data_ptr
(),
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
num_experts
,
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
num_experts
,
topk
,
moe_block_size
,
dev
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
return
c
;
}
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm_moe"
,
&
marlin_gemm_moe
);
}
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
92ec5d8e
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
csrc/moe_align_
block_size
_kernels.cu
→
csrc/moe
/moe
_align_
sum
_kernels.cu
View file @
4d3a2c28
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCAtomics.cuh>
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#include "dispatch_utils.h"
#include "
../
dispatch_utils.h"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX_SHARED_MEM_SIZE 64 * 1024
#define MAX_SHARED_MEM_SIZE 64 * 1024
namespace
vllm
{
namespace
vllm
{
namespace
moe
{
namespace
{
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
...
@@ -37,14 +39,14 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -37,14 +39,14 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t
*
tokens_cnts
=
nullptr
;
int32_t
*
tokens_cnts
=
nullptr
;
int32_t
*
cumsum
=
nullptr
;
int32_t
*
cumsum
=
nullptr
;
if
(
experts_num_exceed_limit
)
{
if
(
experts_num_exceed_limit
)
{
// 2d tensor with shape (
num_experts
+ 1, num_experts)
// 2d tensor with shape (
blockDim.x
+ 1, num_experts)
tokens_cnts
=
global_tokens_cnts_ptr
;
tokens_cnts
=
global_tokens_cnts_ptr
;
// 1d tensor with shape (num_experts + 1)
// 1d tensor with shape (num_experts + 1)
cumsum
=
shared_mem
;
cumsum
=
shared_mem
;
}
else
{
}
else
{
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (
num_experts
+ 1, num_experts)
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (
blockDim.x
+ 1, num_experts)
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
}
}
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
...
@@ -63,10 +65,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -63,10 +65,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads
();
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
}
__syncthreads
();
__syncthreads
();
...
@@ -89,9 +93,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -89,9 +93,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
* blocks and stores the corresponding expert_id for each block.
*/
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
if
(
threadIdx
.
x
<
num_experts
)
{
i
+=
block_size
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
}
/**
/**
...
@@ -116,6 +122,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
...
@@ -116,6 +122,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
}
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., topk, d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
scalar_t
x
=
0.0
;
#pragma unroll
for
(
int
k
=
0
;
k
<
TOPK
;
++
k
)
{
x
+=
VLLM_LDG
(
&
input
[
token_idx
*
TOPK
*
d
+
k
*
d
+
idx
]);
}
out
[
token_idx
*
d
+
idx
]
=
x
;
}
}
}
// namespace moe
}
// namespace vllm
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
...
@@ -125,7 +149,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -125,7 +149,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
int32_t
shared_mem_normal
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
int32_t
shared_mem_normal
=
((
num_thread
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
sizeof
(
int32_t
);
const
bool
experts_num_exceed_limit
=
shared_mem_normal
>
MAX_SHARED_MEM_SIZE
;
const
bool
experts_num_exceed_limit
=
shared_mem_normal
>
MAX_SHARED_MEM_SIZE
;
...
@@ -146,8 +171,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -146,8 +171,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int32_t
>
(),
num_experts
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
());
}
else
{
}
else
{
// set dynamic shared mem
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
,
false
>
;
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
,
false
>
;
...
@@ -159,6 +184,48 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -159,6 +184,48 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
nullptr
,
num_experts
,
block_size
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
nullptr
,
num_experts
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
());
}
}
});
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
torch
::
Tensor
&
output
)
// [num_tokens, hidden_size]
{
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
num_tokens
=
output
.
numel
()
/
hidden_size
;
const
int
topk
=
input
.
size
(
1
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
switch
(
topk
)
{
case
2
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
2
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
3
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
3
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
});
break
;
case
4
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
4
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
default:
at
::
sum_out
(
output
,
input
,
1
);
break
;
}
}
}
csrc/moe/moe_ops.h
View file @
4d3a2c28
...
@@ -5,3 +5,10 @@
...
@@ -5,3 +5,10 @@
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
torch
::
Tensor
&
gating_output
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
csrc/moe/torch_bindings.cpp
View file @
4d3a2c28
#include "core/registration.h"
#include "core/registration.h"
#include "moe_ops.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
// Apply topk softmax to the gating outputs.
...
@@ -9,16 +8,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -9,16 +8,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"token_expert_indices, Tensor gating_output) -> ()"
);
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor! input, Tensor output) -> ()"
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m
.
def
(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
#ifndef USE_ROCM
#ifndef USE_ROCM
m
.
def
(
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int b_q_type, SymInt size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
" -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
// conditionally compiled so impl registration is in source file
#endif
#endif
}
}
...
...
csrc/ops.h
View file @
4d3a2c28
...
@@ -5,6 +5,30 @@
...
@@ -5,6 +5,30 @@
#include "core/scalar_type.hpp"
#include "core/scalar_type.hpp"
#include <vector>
torch
::
Tensor
weak_ref_tensor
(
torch
::
Tensor
&
tensor
)
{
// Ensure tensor is on CUDA
if
(
!
tensor
.
is_cuda
())
{
throw
std
::
runtime_error
(
"Tensor must be on CUDA device"
);
}
// Get the raw data pointer
void
*
data_ptr
=
tensor
.
data_ptr
();
// Get tensor sizes and strides
std
::
vector
<
int64_t
>
sizes
=
tensor
.
sizes
().
vec
();
std
::
vector
<
int64_t
>
strides
=
tensor
.
strides
().
vec
();
// Get tensor options (dtype, device)
auto
options
=
tensor
.
options
();
// Create a new tensor from the raw data pointer
auto
new_tensor
=
torch
::
from_blob
(
data_ptr
,
sizes
,
strides
,
options
);
return
new_tensor
;
}
void
paged_attention_v1
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
@@ -158,6 +182,24 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
...
@@ -158,6 +182,24 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
weight
,
torch
::
Tensor
&
scales
,
double
const
epsilon
,
std
::
optional
<
torch
::
Tensor
>
scale_ub
,
std
::
optional
<
torch
::
Tensor
>
residual
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
...
@@ -187,6 +229,9 @@ void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
...
@@ -187,6 +229,9 @@ void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
double
threshold
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
@@ -231,62 +276,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
...
@@ -231,62 +276,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
namespace
machete
{
std
::
vector
<
std
::
string
>
supported_schedules
(
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
);
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
};
// namespace machete
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
#endif
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
int64_t
n
);
...
@@ -297,11 +288,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
...
@@ -297,11 +288,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
#ifndef USE_ROCM
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
@@ -316,14 +303,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
...
@@ -316,14 +303,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#endif
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
@@ -351,48 +330,46 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...
@@ -351,48 +330,46 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// c10::optional<torch::Tensor> const& scale_ub);
// c10::optional<torch::Tensor> const& scale_ub);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
experts_ids
,
const
torch
::
Tensor
&
C
,
torch
::
Tensor
num_tokens_post_pad
);
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
std
::
vector
<
torch
::
Tensor
>
selective_scan_fwd
(
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
bool
delta_softplus
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
at
::
Tensor
causal_conv1d_update
(
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices
);
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
seq_idx_
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
initial_states_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
final_states_out_
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
bool
silu_activation
);
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
const
std
::
vector
<
std
::
string
>&
handles
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
bool
full_nvlink
);
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
dispose
(
fptr_t
_fa
);
void
dispose
(
fptr_t
_fa
);
int64_t
meta_size
();
int64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
const
std
::
vector
<
std
::
string
>&
handles
,
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
const
std
::
vector
<
int64_t
>&
offsets
);
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
std
::
tuple
<
torch
::
Tensor
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
void
register_graph_buffers
(
fptr_t
_fa
,
fptr_t
_fa
);
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
#endif
csrc/opt/activation_kernels_opt.cu
View file @
4d3a2c28
...
@@ -107,8 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -107,8 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return
(
T
)(
0.5
f
*
f
*
(
1.0
f
+
::
tanhf
(
inner
)));
return
(
T
)(
0.5
f
*
f
*
(
1.0
f
+
::
tanhf
(
inner
)));
}
}
template
<
typename
T
>
__device__
__forceinline__
T
fatrelu_kernel
(
const
T
&
x
,
const
float
threshold
)
{
const
float
f
=
(
float
)
x
;
return
(
T
)(
f
>
threshold
?
f
:
0.0
f
);
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
float
)>
__global__
void
act_and_mul_kernel_with_param
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
const
float
param
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
}
}
// namespace vllm
}
// namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
...
@@ -163,4 +196,10 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
...
@@ -163,4 +196,10 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d],
torch
::
Tensor
&
input
,
// [..., 2 * d]
double
threshold
)
{
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM
(
vllm
::
fatrelu_kernel
,
threshold
);
}
}
\ No newline at end of file
csrc/opt/layernorm_kernels_opt.cu
View file @
4d3a2c28
#include <torch/all.h>
#include "type_convert.cuh"
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#else
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
...
@@ -55,154 +48,6 @@ __global__ void rms_norm_kernel(
...
@@ -55,154 +48,6 @@ __global__ void rms_norm_kernel(
}
}
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template
<
>
struct
_typeConvert
<
c10
::
Half
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__half
;
using
packed_hip_type
=
__half2
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__nv_bfloat16
;
using
packed_hip_type
=
__nv_bfloat162
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16Vec
{
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert
(
width
>
0
&&
(
width
&
(
width
-
1
))
==
0
,
"Width is not a positive power of 2!"
);
using
Converter
=
_typeConvert
<
scalar_t
>
;
using
T1
=
typename
Converter
::
hip_type
;
using
T2
=
typename
Converter
::
packed_hip_type
;
T1
data
[
width
];
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
temp_f
.
x
*=
scale
;
temp_f
.
y
*=
scale
;
T2
temp
=
Converter
::
convert
(
temp_f
);
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
data
[
i
]
=
Converter
::
convert
(
temp
);
}
}
return
*
this
;
}
__device__
float
sum_squares
()
const
{
float
result
=
0.0
f
;
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
x
=
Converter
::
convert
(
data
[
i
]);
result
+=
x
*
x
;
}
}
return
result
;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
Additional optimizations we can make in this case are
...
...
csrc/prepare_inputs/advance_step.cu
View file @
4d3a2c28
...
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
...
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int64_t
const
block_tables_stride
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
...
@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
...
@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
}
}
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
&
t
,
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
const
&
t
,
int64_t
const
size_0
,
int64_t
const
size_1
,
int64_t
const
size_0
,
int64_t
const
size_1
,
c10
::
ScalarType
const
type
)
{
c10
::
ScalarType
const
type
)
{
bool
size_0_cond
=
true
;
bool
size_0_cond
=
true
;
...
@@ -77,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
...
@@ -77,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}
}
}
/// each thread processes a block per query
__global__
void
advance_step_flashinfer_kernel
(
__global__
void
advance_step_flashinfer_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
block_size
,
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
...
@@ -123,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
...
@@ -123,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
*
paged_kv_indptr_ptr
,
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
// Update paged_kv_indptr
// Update paged_kv_indptr
if
(
idx
==
0
)
{
paged_kv_indptr_ptr
[
idx
]
=
0
;
}
if
(
idx
<
num_queries
)
{
if
(
idx
<
num_queries
)
{
int
sum
=
0
;
int
sum
=
0
;
for
(
int
i
=
0
;
i
<=
idx
;
++
i
)
{
for
(
int
i
=
0
;
i
<=
idx
;
++
i
)
{
...
@@ -135,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
...
@@ -135,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}
}
__global__
void
advance_step_flashinfer_indices_kernel
(
__global__
void
advance_step_flashinfer_indices_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_indices_ptr
,
int64_t
const
max_num_blocks_per_seq
,
int
*
paged_kv_indices_ptr
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
// note: max_num_blocks_per_seq = block_tables.stride(0)
int
row
=
idx
/
block_tables_stride
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
col
=
idx
%
block_tables_stride
;
// when cuda graphs are enabled, paged_kv_indptr tensor
if
(
row
<
num_queries
&&
col
<
block_table_bound_ptr
[
row
])
{
// has to be updated for the padded queries
paged_kv_indices_ptr
[
paged_kv_indptr_ptr
[
row
]
+
col
]
=
// tid represents a query# for paged_kv_indptr tensor
block_tables_ptr
[
row
*
block_tables_stride
+
col
];
if
(
num_queries
<
tid
&&
tid
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
tid
]
=
paged_kv_indptr_ptr
[
num_queries
];
}
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if
(
num_queries
<
row
&&
row
<=
num_seqs
)
{
// each thread processes a block_ptr in block_tables
paged_kv_indptr_ptr
[
row
]
=
paged_kv_indptr_ptr
[
num_queries
];
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for
(
int
idx
=
tid
;
idx
<
(
num_seqs
*
max_num_blocks_per_seq
);
idx
+=
(
gridDim
.
x
*
blockDim
.
x
))
{
// block_tables-row = paged_kv_indptr[queryNum]
int
queryNum
=
idx
/
max_num_blocks_per_seq
;
int
col
=
idx
%
max_num_blocks_per_seq
;
if
(
queryNum
<
num_queries
&&
col
<
block_table_bound_ptr
[
queryNum
])
{
int
indices_arr_idx
=
paged_kv_indptr_ptr
[
queryNum
]
+
col
;
int
block_tables_idx
=
queryNum
*
max_num_blocks_per_seq
+
col
;
paged_kv_indices_ptr
[
indices_arr_idx
]
=
block_tables_ptr
[
block_tables_idx
];
}
}
}
}
}
...
@@ -211,7 +238,7 @@ void advance_step_flashinfer(
...
@@ -211,7 +238,7 @@ void advance_step_flashinfer(
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %
d
\n
"
,
block_tables
.
stride
(
0
));
printf
(
" block_tables.stride(0) = %
zu
\n
"
,
block_tables
.
stride
(
0
));
}
}
// Verify all tensors
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
...
@@ -236,22 +263,16 @@ void advance_step_flashinfer(
...
@@ -236,22 +263,16 @@ void advance_step_flashinfer(
int
threads
;
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
if
(
logging
)
{
printf
(
"launching kernel with %d blocks
\n
"
,
blocks
);
}
// TODO(will): support arbitrary block_tables stride
int
block_tables_stride
=
block_tables
.
stride
(
0
);
if
((
blocks
*
threads
)
/
block_tables
.
stride
(
0
)
<
num_queries
)
{
TORCH_CHECK
((
blocks
*
threads
>
num_queries
),
TORCH_CHECK
(
false
,
"multi-step: not enough threads to map to num_queries = "
,
"multi-step: not enough threads to map block_table to"
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
"of seqs,"
,
if
(
logging
)
{
" increasing the block size or take smaller steps."
,
printf
(
"launching kernels with %d blocks and %d threads
\n
"
,
blocks
,
" num_queries = "
,
num_queries
,
threads
);
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
}
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
...
@@ -270,7 +291,7 @@ void advance_step_flashinfer(
...
@@ -270,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
...
@@ -303,4 +324,4 @@ void advance_step_flashinfer(
...
@@ -303,4 +324,4 @@ void advance_step_flashinfer(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
}
}
\ No newline at end of file
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
4d3a2c28
...
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
...
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
scale_type
const
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
/
scale
);
}
}
}
}
...
@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
...
@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
}
}
...
@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
...
@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
float
const
zero
=
0.0
f
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
float
val
=
static_cast
<
float
>
(
input
[
i
]);
val
=
val
>
zero
?
val
:
-
val
;
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
}
...
@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
...
@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
*
tmp_scale
);
}
}
}
}
...
@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
...
@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int
const
token_idx
=
blockIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
// Scan for the min and max value for this token
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
val
=
static_cast
<
float
>
(
input
[
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
max_val
=
std
::
max
(
max_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
}
}
...
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
...
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
4d3a2c28
...
@@ -8,6 +8,10 @@
...
@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using
namespace
vllm
;
/*
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
...
@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
return
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm75_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
...
@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
...
@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
if
(
azp
)
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm75_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm75_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
}
...
@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
return
cutlass_gemm_sm80_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
...
@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
...
@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
if
(
azp
)
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm80_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm80_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
}
...
@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
return
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
return
cutlass_gemm_sm89_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
else
{
}
else
{
...
@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
...
@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
return
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
vllm
::
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
return
cutlass_gemm_sm89_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
}
...
@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
...
@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogue
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
...
@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
if
(
azp
)
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm89_epilogue
<
vllm
::
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm89_epilogue
<
c2x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
4d3a2c28
...
@@ -21,7 +21,6 @@
...
@@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
#include "common.hpp"
// clang-format on
// clang-format on
...
@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
...
@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif
#endif
}
}
};
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrZeroLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrZeroBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
// it would technically work but no use case as data_ptr is never nullptr
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static_assert
(
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
return
Arguments
{
data_ptr
};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
protected:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBiasAzp
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowOrZeroLoad
<
ElementD
>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using
AzpWithAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute float(accum - azp_adj), both operands are int32_t
using
ComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAzp
,
Accum
,
AzpWithAdj
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAzp
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBiasAzpToken
:
protected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowOrZeroLoad
<
ElementD
>;
// Per-token azp term, shape (m,1)
using
Azp
=
typename
SUPER
::
template
ColLoad
<
int32_t
>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using
AzpAdj
=
typename
SUPER
::
template
RowLoad
<
int32_t
>;
// Compute azp * azp_adj
using
ComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
int32_t
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAzp
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAzp
,
Azp
,
AzpAdj
>
;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using
ComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
minus
,
float
,
int32_t
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeAcc
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeAcc
,
Accum
,
EVTComputeAzp
>
;
using
ComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTComputeScaleB
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
ComputeScaleBiasA
,
ScaleA
,
EVTComputeScaleB
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
azp_args
=
SUPER
::
template
args_from_tensor
<
Azp
,
int32_t
>(
azp
);
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
}
};
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
template
<
typename
Arch
,
template
<
typename
>
typename
ArchGuard
,
typename
ElementAB_
,
typename
ElementD_
,
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
template
<
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
4d3a2c28
...
@@ -23,11 +23,12 @@
...
@@ -23,11 +23,12 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "
broadcast_load
_epilogue_c3x.hpp"
#include "
cutlass_extensions/epilogue/scaled_mm
_epilogue
s
_c3x.hpp"
#include "common.hpp"
#include "common.hpp"
// clang-format on
// clang-format on
using
namespace
cute
;
using
namespace
cute
;
using
namespace
vllm
;
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
...
@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
...
@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
#endif
#endif
}
}
};
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
&&
!
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
||
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
 * Same operation as ScaledEpilogue, with a bias added on top.
 * The bias can also carry the per-tensor azp case: the activation zero point
 * (azp) yields a correction term that is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Base = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename Base::Accum;
  using ScaleA = typename Base::template ColOrScalarLoad<float>;
  using ScaleB = typename Base::template RowOrScalarLoad<float>;
  using Bias = typename Base::template RowLoad<ElementD>;

  // Inner node: ScaleB * Accum, evaluated in float.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaleBTimesAccum =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, ScaleB, Accum>;

  // Outer node: multiply_add over (ScaleA, inner node, Bias).
  using ScaleAMulAddBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAMulAddBias, ScaleA,
                                         ScaleBTimesAccum, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds {a_scales args, {b_scales args}, bias args} for EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto scale_a = Base::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b = Base::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_in = Base::template args_from_tensor<Bias, ElementD>(bias);

    typename ScaleBTimesAccum::Arguments inner{scale_b};
    return ArgumentType{scale_a, inner, bias_in};
  }
};
/*
 * Epilogue with direct support for per-tensor azp in int32 form.
 * Unlike the per-token variant, it takes a single azp_adj term that should
 * already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Base = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename Base::Accum;
  using ScaleA = typename Base::template ColOrScalarLoad<float>;
  using ScaleB = typename Base::template RowOrScalarLoad<float>;
  using Bias = typename Base::template RowLoad<ElementD, true>;

  // The full AZP term, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename Base::template RowLoad<int32_t>;

  // Node 1: float(accum - azp_adj); both operands are int32_t.
  using SubtractAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccumMinusAzp =
      cutlass::epilogue::fusion::Sm90EVT<SubtractAzp, Accum, AzpWithAdj>;

  // Node 2: ScaleB * (node 1), evaluated in float.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, ScaleB, AccumMinusAzp>;

  // Node 3 (root): multiply_add over (ScaleA, node 2, Bias).
  using ScaleAMulAddBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAMulAddBias, ScaleA, ScaledByB,
                                         Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested EVTCompute arguments. The accumulator child of the
  // azp node takes an empty argument ({}); bias is optional.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a = Base::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b = Base::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_in = Base::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_in =
        Base::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename AccumMinusAzp::Arguments minus_azp{{}, azp_adj_in};
    typename ScaledByB::Arguments scaled_by_b{scale_b, minus_azp};
    return ArgumentType{scale_a, scaled_by_b, bias_in};
  }
};
/*
 * Epilogue with per-token azp support. The correction term is computed and
 * applied as a rank-1 update: materializing it would need O(m*n) space,
 * whereas this form only needs O(m+n).
 * The azp term is a 1D tensor of shape (m,1) holding the unscaled zero point
 * of each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Base = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename Base::Accum;
  using ScaleA = typename Base::template ColOrScalarLoad<float>;
  using ScaleB = typename Base::template RowOrScalarLoad<float>;
  using Bias = typename Base::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1).
  using Azp = typename Base::template ColLoad<int32_t>;

  // The AZP adjustment term, J @ B, shape (1,n).
  using AzpAdj = typename Base::template RowLoad<int32_t>;

  // Node 1: azp * azp_adj, in int32_t.
  using MulAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AzpTimesAdj = cutlass::epilogue::fusion::Sm90EVT<MulAzp, Azp, AzpAdj>;

  // Node 2: float(accum - azp*azp_adj); all operands are int32_t.
  using SubtractAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccumMinusAzp =
      cutlass::epilogue::fusion::Sm90EVT<SubtractAzp, Accum, AzpTimesAdj>;

  // Node 3: ScaleB * (node 2), evaluated in float.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, ScaleB, AccumMinusAzp>;

  // Node 4 (root): multiply_add over (ScaleA, node 3, Bias).
  using ScaleAMulAddBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAMulAddBias, ScaleA, ScaledByB,
                                         Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested EVTCompute arguments, innermost node first. The
  // accumulator child of the subtraction node takes an empty argument ({}).
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a = Base::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b = Base::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_in = Base::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_in = Base::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_in =
        Base::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename AzpTimesAdj::Arguments azp_product{azp_in, azp_adj_in};
    typename AccumMinusAzp::Arguments minus_azp{{}, azp_product};
    typename ScaledByB::Arguments scaled_by_b{scale_b, minus_azp};
    return ArgumentType{scale_a, scaled_by_b, bias_in};
  }
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
...
@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBias
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
b_scales
);
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
...
@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
...
@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
if
(
azp
)
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBiasAzpToken
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
ScaledEpilogueBiasAzp
>
(
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
4d3a2c28
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined
CUDA_VERSION && CUDA_VERSION >= 12000
#if defined
ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
@@ -114,26 +114,41 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -114,26 +114,41 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Hopper
// Guard against compilation issues for sm90 kernels
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
return
;
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
#endif
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
return
;
}
if
(
version_num
>=
75
)
{
// Turing
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
@@ -174,25 +189,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
...
@@ -174,25 +189,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype "
,
c
.
dtype
());
"currently bias dtype must match output dtype "
,
c
.
dtype
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#
if
defined CUDA_VERSION && CUDA_VERSION >= 12000
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_azp_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
cutlass_scaled_mm_azp_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
#else
return
;
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
#endif
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
// Ada Lovelace
cutlass_scaled_mm_azp_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
cutlass_scaled_mm_azp_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
// Ampere
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
{
return
;
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
}
\ No newline at end of file
csrc/quantization/fp8/common.cu
View file @
4d3a2c28
#include <ATen/cuda/CUDAContext.h>
#include "common.cuh"
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#else
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/hipcub.hpp>
#endif
#endif
#ifndef USE_ROCM
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
namespace
vllm
{
namespace
vllm
{
// Atomically raises *addr to at least `value` and returns the value that was
// stored before the update. Non-negative inputs go through an integer
// atomicMax on the bit pattern; negative inputs go through an unsigned
// atomicMin on the bit pattern.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const previous_bits = atomicMax((int*)addr, __float_as_int(value));
    return __int_as_float(previous_bits);
  }
  unsigned int const previous_bits =
      atomicMin((unsigned int*)addr, __float_as_uint(value));
  return __uint_as_float(previous_bits);
}
// Scales `val` (multiplying by `scale` when is_scale_inverted, dividing by it
// otherwise), clamps the result to [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converts
// it to FP8_TYPE.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The condition is a compile-time constant, so only one arm is ever taken
  // for a given instantiation.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  float const clamped = fmax(-FP8_E4M3_MAX, fmin(scaled, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(clamped);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: 1D grid and block; blockDim.x must not exceed 1024
// because the shared scratch buffer has 1024 slots. Any block size within
// that limit is handled, including sizes that are not a power of two.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Index math is widened to 64 bits before multiplying: blockDim, blockIdx
  // and gridDim are 32-bit unsigned, so their product would otherwise wrap
  // for very large launches even though num_elems is 64-bit.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t const first =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum over all values processed by the current thread
  // in cache[threadIdx.x]. The accumulator is float because the shared cache
  // and the final scale are float.
  float thread_max = 0.0f;
  for (int64_t i = first; i < num_elems; i += stride) {
    thread_max = fmaxf(thread_max, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = thread_max;

  __syncthreads();

  // Parallel reduction within the thread block. `active` is the number of
  // live slots; each round folds the upper part [half, active) onto the lower
  // part, rounding `half` up so an odd middle slot is carried over instead of
  // being dropped. Writers touch slots < active - half <= half and readers
  // touch slots >= half, so a round never races with itself.
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      cache[threadIdx.x] =
          fmaxf(cache[threadIdx.x], cache[threadIdx.x + half]);
    }
    __syncthreads();
    active = half;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
// Pack of four scalar_t values (x, y, z, w), declared with 8-byte alignment.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x, y, z, w;
};
// Pack of four FP8_TYPE values (x, y, z, w), declared with 4-byte alignment.
struct __align__(4) float8x4_t {
  FP8_TYPE x, y, z, w;
};
// Returns the largest absolute value seen by this thread. The thread visits
// 4-element packs tid, tid + step, tid + 2*step, ... and then, in the same
// strided fashion, whatever elements are left after the last full pack.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* packs =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_packs = num_elems >> 2;
  float absmax_val = 0.0f;

  // Folds one element into the running maximum.
  auto accumulate = [&absmax_val](scalar_t v) {
    absmax_val = max(absmax_val, fabs(v));
  };

#pragma unroll 4
  for (int64_t p = tid; p < num_packs; p += step) {
    vec4_t<scalar_t> pack = packs[p];
    accumulate(pack.x);
    accumulate(pack.y);
    accumulate(pack.z);
    accumulate(pack.w);
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t e = num_packs * 4 + tid; e < num_elems; e += step) {
    accumulate(input[e]);
  }

  return absmax_val;
}
// Converts this thread's share of `input` to FP8 via scaled_fp8_conversion and
// writes it to `out`. The thread handles 4-element packs tid, tid + step, ...
// and then the elements left after the last full pack with the same stride.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* src_packs =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* dst_packs = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_packs = num_elems >> 2;

  // Quantizes a single element with the shared scale.
  auto quantize = [scale](scalar_t v) {
    return scaled_fp8_conversion<is_scale_inverted>(static_cast<float>(v),
                                                    scale);
  };

#pragma unroll 4
  for (int64_t p = tid; p < num_packs; p += step) {
    vec4_t<scalar_t> src = src_packs[p];
    float8x4_t dst;

    dst.x = quantize(src.x);
    dst.y = quantize(src.y);
    dst.z = quantize(src.z);
    dst.w = quantize(src.w);
    dst_packs[p] = dst;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t e = num_packs * 4 + tid; e < num_elems; e += step) {
    out[e] = quantize(input[e]);
  }
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
...
@@ -204,8 +35,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -204,8 +35,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
// Use int64 to avoid overflowing an int32 when calculating this offset
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
// aligned at 8-byte and 4-byte addresses respectively.
...
...
csrc/quantization/fp8/common.cuh
0 → 100644
View file @
4d3a2c28
#pragma once
#include "quantization/vectorization.cuh"
#include <cmath>
#include <c10/core/ScalarType.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
constexpr
static
auto
kFp8Type
=
c10
::
CppTypeToScalarType
<
FP8_TYPE
>::
value
;
namespace
vllm
{
// Atomically raises *addr to at least `value` and returns the value that was
// stored before the update. Non-negative inputs go through an integer
// atomicMax on the bit pattern; negative inputs go through an unsigned
// atomicMin on the bit pattern.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const previous_bits = atomicMax((int*)addr, __float_as_int(value));
    return __int_as_float(previous_bits);
  }
  unsigned int const previous_bits =
      atomicMin((unsigned int*)addr, __float_as_uint(value));
  return __uint_as_float(previous_bits);
}
// Scales `val` (multiplying by `scale` when is_scale_inverted, dividing by it
// otherwise), clamps the result to [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converts
// it to FP8_TYPE.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The condition is a compile-time constant, so only one arm is ever taken
  // for a given instantiation.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  float const clamped = fmax(-FP8_E4M3_MAX, fmin(scaled, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(clamped);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: 1D grid and block; blockDim.x must not exceed 1024
// because the shared scratch buffer has 1024 slots. Any block size within
// that limit is handled, including sizes that are not a power of two.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Index math is widened to 64 bits before multiplying: blockDim, blockIdx
  // and gridDim are 32-bit unsigned, so their product would otherwise wrap
  // for very large launches even though num_elems is 64-bit.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t const first =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum over all values processed by the current thread
  // in cache[threadIdx.x]. The accumulator is float because the shared cache
  // and the final scale are float.
  float thread_max = 0.0f;
  for (int64_t i = first; i < num_elems; i += stride) {
    thread_max = fmaxf(thread_max, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = thread_max;

  __syncthreads();

  // Parallel reduction within the thread block. `active` is the number of
  // live slots; each round folds the upper part [half, active) onto the lower
  // part, rounding `half` up so an odd middle slot is carried over instead of
  // being dropped. Writers touch slots < active - half <= half and readers
  // touch slots >= half, so a round never races with itself.
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      cache[threadIdx.x] =
          fmaxf(cache[threadIdx.x], cache[threadIdx.x + half]);
    }
    __syncthreads();
    active = half;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
// Returns the largest absolute value seen by this thread. The thread visits
// 4-element packs tid, tid + step, tid + 2*step, ... and then, in the same
// strided fashion, whatever elements are left after the last full pack.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* packs =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_packs = num_elems >> 2;
  float absmax_val = 0.0f;

  // Folds one element into the running maximum.
  auto accumulate = [&absmax_val](scalar_t v) {
    absmax_val = max(absmax_val, fabs(v));
  };

#pragma unroll 4
  for (int64_t p = tid; p < num_packs; p += step) {
    vec4_t<scalar_t> pack = packs[p];
    accumulate(pack.x);
    accumulate(pack.y);
    accumulate(pack.z);
    accumulate(pack.w);
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t e = num_packs * 4 + tid; e < num_elems; e += step) {
    accumulate(input[e]);
  }

  return absmax_val;
}
// Converts this thread's share of `input` to FP8 via scaled_fp8_conversion and
// writes it to `out`. The thread handles 4-element packs tid, tid + step, ...
// and then the elements left after the last full pack with the same stride.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<FP8_TYPE>;
  // Vectorized input/output to better utilize memory bandwidth.
  auto const* src_packs = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* dst_packs = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_packs = num_elems >> 2;

  // Quantizes a single element with the shared scale.
  auto quantize = [scale](scalar_t v) {
    return scaled_fp8_conversion<is_scale_inverted>(static_cast<float>(v),
                                                    scale);
  };

#pragma unroll 4
  for (int64_t p = tid; p < num_packs; p += step) {
    vec4_t<scalar_t> src = src_packs[p];
    float8x4_t dst;

    dst.x = quantize(src.x);
    dst.y = quantize(src.y);
    dst.z = quantize(src.z);
    dst.w = quantize(src.w);
    dst_packs[p] = dst;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t e = num_packs * 4 + tid; e < num_elems; e += step) {
    out[e] = quantize(input[e]);
  }
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/fp8/fp8_marlin.cu
View file @
4d3a2c28
...
@@ -22,6 +22,8 @@
...
@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using
namespace
marlin
;
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
#endif
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
0 → 100644
View file @
4d3a2c28
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"
namespace
vllm
{
// Vectorized path of the fused RMS norm + dynamic per-token quantization.
// Runs three stages from vllm::vectorized in order: compute_rms,
// compute_dynamic_per_token_scales, norm_and_quant.
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  namespace vec = vllm::vectorized;

  // Stage 1: RMS statistic.
  float rms = 0.0f;
  vec::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
                                           var_epsilon, residual);

  // Stage 2: per-token quantization scale.
  float token_scale = 0.0f;
  vec::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // Stage 3: RMS norm + quant. int8 output receives the inverted scale;
  // FP8 keeps token_scale un-inverted for exact match with FBGemm.
  constexpr bool is_int8_out = std::is_same_v<scalar_out_t, int8_t>;
  float const quant_scale = is_int8_out ? 1.0f / token_scale : token_scale;
  vec::norm_and_quant<scalar_t, scalar_out_t, is_int8_out, has_residual>(
      out, input, weight, rms, quant_scale, hidden_size, residual);
}
// RMS norm + quant kernel
// Hands off to rms_norm_dynamic_per_token_quant_vec when hidden_size is a
// multiple of 4; otherwise runs the non-vectorized helpers in the same three
// stages (rms, per-token scale, norm + quant).
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  if (hidden_size % 4 == 0) {
    rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
                                         has_residual>(
        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
        hidden_size, residual);
    return;
  }

  // Stage 1: RMS statistic.
  float rms = 0.0f;
  vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
                                            var_epsilon, residual);

  // Stage 2: per-token quantization scale.
  float token_scale = 0.0f;
  vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // Stage 3: RMS norm + quant. int8 output receives the inverted scale;
  // FP8 keeps token_scale un-inverted for exact match with FBGemm.
  constexpr bool is_int8_out = std::is_same_v<scalar_out_t, int8_t>;
  float const quant_scale = is_int8_out ? 1.0f / token_scale : token_scale;
  vllm::norm_and_quant<scalar_t, scalar_out_t, is_int8_out, has_residual>(
      out, input, weight, rms, quant_scale, hidden_size, residual);
}
}
// namespace vllm
// Residual add + RMS norm + dynamic per token
// Host-side launcher: picks the kernel instantiation for the output dtype
// (via VLLM_DISPATCH_QUANT_TYPES) and for the presence of a residual, then
// launches one block per token with min(hidden_size, 1024) threads on the
// current CUDA stream of input's device.
// An input with no elements is a no-op.
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> const& scale_ub,
    std::optional<at::Tensor>& residual) {
  // Nothing to do for an empty input. Returning here also avoids dividing by
  // hidden_size == 0 below and launching a kernel with a zero-sized grid,
  // which is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int32_t hidden_size = input.size(-1);
  int32_t num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Lower bound handed to the kernel as min_scaling_factor: float epsilon for
  // int8 output, 1 / (fp8 e4m3 max * 512) otherwise.
  const float min_scaling_factor =
      out.dtype() == torch::kInt8
          ? std::numeric_limits<float>::epsilon()
          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);

  // Optional upper bound; the kernel receives nullptr when it is absent.
  float const* scale_ub_ptr =
      scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr;

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        true>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(),
                  weight.data_ptr<scalar_in_t>(), scale_ub_ptr, var_epsilon,
                  min_scaling_factor, hidden_size,
                  residual->data_ptr<scalar_in_t>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        false>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(),
                  weight.data_ptr<scalar_in_t>(), scale_ub_ptr, var_epsilon,
                  min_scaling_factor, hidden_size, nullptr);
        });
  }
}
// Entry point for (optional residual add +) RMS norm + dynamic per-token
// quantization. Validates the arguments, then dispatches on the floating-point
// dtype of `input` to rms_norm_dynamic_per_token_quant_dispatch.
//
// Parameters:
//   out         - quantized output, dtype kFp8Type or int8, contiguous,
//                 [..., hidden_size]
//   input       - contiguous input, [..., hidden_size]
//   weight      - contiguous norm weight with the same dtype as `input`,
//                 [hidden_size]
//   scales      - contiguous float32 buffer with room for one scale per
//                 token, [num_tokens]
//   var_epsilon - variance epsilon used in the norm calculation
//   scale_ub    - optional scale upper bound; only accepted when `out` is
//                 kFp8Type
//   residual    - optional contiguous residual tensor with the same dtype and
//                 element count as `input`
//
// Raises (via TORCH_CHECK) when any of the above requirements is violated.
// The shape, contiguity and dtype checks exist because the launched kernel
// receives raw data pointers plus `hidden_size` only, so a mis-sized or
// strided tensor would otherwise be accessed out of bounds.
void rms_norm_dynamic_per_token_quant(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,  // Variance epsilon used in norm calculation
    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }
  TORCH_CHECK(scales.dtype() == torch::kFloat32);

  TORCH_CHECK(input.dim() >= 1, "input must have at least one dimension");
  int64_t const hidden_size = input.size(-1);
  // Guard the division so an empty hidden dimension is treated as zero tokens.
  int64_t const num_tokens =
      hidden_size > 0 ? input.numel() / hidden_size : 0;

  TORCH_CHECK(out.numel() == input.numel(),
              "out and input must have the same number of elements");
  TORCH_CHECK(weight.scalar_type() == input.scalar_type(),
              "weight and input must have the same dtype");
  TORCH_CHECK(weight.is_contiguous() && weight.numel() == hidden_size,
              "weight must be contiguous with hidden_size elements");
  TORCH_CHECK(scales.is_contiguous() && scales.numel() >= num_tokens,
              "scales must be contiguous with at least num_tokens elements");
  if (residual.has_value()) {
    TORCH_CHECK(residual->scalar_type() == input.scalar_type(),
                "residual and input must have the same dtype");
    TORCH_CHECK(
        residual->is_contiguous() && residual->numel() == input.numel(),
        "residual must be contiguous with the same number of elements as "
        "input");
  }

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
        rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
            out, input, weight, scales, var_epsilon, scale_ub, residual);
      });
}
Prev
1
…
5
6
7
8
9
10
11
12
13
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment