Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4068f4b5
Unverified
Commit
4068f4b5
authored
Jan 04, 2025
by
Lu Fang
Committed by
GitHub
Jan 05, 2025
Browse files
[MISC] Replace c10::optional with std::optional (#11730)
Signed-off-by:
Lu Fang
<
lufang@fb.com
>
parent
47831430
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
130 additions
and
130 deletions
+130
-130
csrc/attention/paged_attention_v1.cu
csrc/attention/paged_attention_v1.cu
+2
-2
csrc/attention/paged_attention_v2.cu
csrc/attention/paged_attention_v2.cu
+2
-2
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+4
-4
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+5
-5
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+3
-3
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+3
-3
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+3
-3
csrc/cutlass_extensions/torch_utils.hpp
csrc/cutlass_extensions/torch_utils.hpp
+1
-1
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+12
-12
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+11
-11
csrc/ops.h
csrc/ops.h
+23
-23
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+2
-2
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+9
-9
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+3
-3
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+15
-15
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+1
-1
csrc/quantization/machete/machete_mm_kernel.cuh
csrc/quantization/machete/machete_mm_kernel.cuh
+5
-5
csrc/quantization/machete/machete_mm_launcher.cuh
csrc/quantization/machete/machete_mm_launcher.cuh
+12
-12
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+1
-1
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+13
-13
No files found.
csrc/attention/paged_attention_v1.cu
View file @
4068f4b5
...
...
@@ -53,7 +53,7 @@ void paged_attention_v1_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
...
...
@@ -176,7 +176,7 @@ void paged_attention_v1(
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
csrc/attention/paged_attention_v2.cu
View file @
4068f4b5
...
...
@@ -54,7 +54,7 @@ void paged_attention_v2_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
...
...
@@ -187,7 +187,7 @@ void paged_attention_v2(
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
csrc/cpu/attention.cpp
View file @
4068f4b5
...
...
@@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
)
{
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -459,7 +459,7 @@ void paged_attention_v1(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
@@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
)
{
int
max_seq_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -781,7 +781,7 @@ void paged_attention_v2(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
csrc/cpu/quant.cpp
View file @
4068f4b5
...
...
@@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
const
torch
::
Tensor
&
b
,
// [IC, OC], column-major
const
torch
::
Tensor
&
a_scales
,
// [1] or [M]
const
torch
::
Tensor
&
b_scales
,
// [1] or [OC]
const
c10
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
const
std
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
)
{
CPU_KERNEL_GUARD_IN
(
cutlass_scaled_mm
)
// Checks for conformality
...
...
@@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
const
torch
::
Tensor
&
a_scales
,
// [1] or [M]
const
torch
::
Tensor
&
b_scales
,
// [1] or [OC]
const
torch
::
Tensor
&
azp_adj
,
// [OC]
const
c10
::
optional
<
torch
::
Tensor
>&
azp
,
// [1] or [M]
const
c10
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
const
std
::
optional
<
torch
::
Tensor
>&
azp
,
// [1] or [M]
const
std
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
)
{
CPU_KERNEL_GUARD_IN
(
cutlass_scaled_mm_azp
)
// Checks for conformality
...
...
@@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
const
torch
::
Tensor
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
CPU_KERNEL_GUARD_IN
(
static_scaled_int8_quant
)
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
...
...
@@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
torch
::
Tensor
&
out
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scale
,
// [..., 1]
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_scaled_int8_quant
)
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
...
...
csrc/cpu/torch_bindings.cpp
View file @
4068f4b5
...
...
@@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
void
int8_scaled_mm
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
);
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
void
int8_scaled_mm_azp
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
azp_adj
,
const
c10
::
optional
<
torch
::
Tensor
>&
azp
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
);
const
std
::
optional
<
torch
::
Tensor
>&
azp
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
4068f4b5
...
...
@@ -68,7 +68,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static
auto
args_from_tensor
(
std
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static_assert
(
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
...
...
@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
...
...
@@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
4068f4b5
...
...
@@ -67,7 +67,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static
auto
args_from_tensor
(
std
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
||
...
...
@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
...
...
@@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
...
...
csrc/cutlass_extensions/torch_utils.hpp
View file @
4068f4b5
...
...
@@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template
<
typename
Stride
>
static
inline
auto
maybe_make_cute_layout
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
,
std
::
optional
<
torch
::
Tensor
>
const
&
tensor
,
std
::
string_view
name
=
"tensor"
)
{
using
Layout
=
decltype
(
make_cute_layout
<
Stride
>
(
*
tensor
));
...
...
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
4068f4b5
...
...
@@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
const
at
::
Tensor
x
,
const
at
::
Tensor
weight
,
const
at
::
Tensor
out
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
int64_t
pad_slot_id
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
...
...
@@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
const
std
::
optional
<
at
::
Tensor
>
&
bias_
,
const
std
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
std
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
bool
silu_activation
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
...
...
@@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
std
::
optional
<
at
::
Tensor
>
&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
,
const
std
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
std
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
4068f4b5
...
...
@@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const
torch
::
Tensor
out
,
const
torch
::
Tensor
z
,
const
torch
::
Tensor
out_z
,
const
c10
::
optional
<
at
::
Tensor
>&
D
,
const
c10
::
optional
<
at
::
Tensor
>&
delta_bias
,
const
std
::
optional
<
at
::
Tensor
>&
D
,
const
std
::
optional
<
at
::
Tensor
>&
delta_bias
,
const
torch
::
Tensor
ssm_states
,
bool
has_z
,
bool
delta_softplus
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
,
int64_t
pad_slot_id
)
{
...
...
@@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>
&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>
&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>
&
delta_bias_
,
const
std
::
optional
<
torch
::
Tensor
>
&
D_
,
const
std
::
optional
<
torch
::
Tensor
>
&
z_
,
const
std
::
optional
<
torch
::
Tensor
>
&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
std
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
std
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
std
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
...
...
csrc/ops.h
View file @
4068f4b5
...
...
@@ -33,7 +33,7 @@ void paged_attention_v1(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
@@ -44,7 +44,7 @@ void paged_attention_v2(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
...
...
@@ -153,15 +153,15 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
);
...
...
@@ -169,7 +169,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_compress_entry
(
torch
::
Tensor
&
a_compressed
,
torch
::
Tensor
&
e
,
torch
::
Tensor
const
&
a
);
...
...
@@ -177,11 +177,11 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
);
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
...
...
@@ -198,34 +198,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void
dynamic_per_token_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
const
std
::
optional
<
torch
::
Tensor
>&
D_
,
const
std
::
optional
<
torch
::
Tensor
>&
z_
,
const
std
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
const
std
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_states
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
4068f4b5
...
...
@@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
...
...
@@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
torch
::
Tensor
&
scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scales
.
is_contiguous
());
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
4068f4b5
...
...
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
4068f4b5
...
...
@@ -51,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -70,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
4068f4b5
...
...
@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
...
...
@@ -84,7 +84,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
...
@@ -148,8 +148,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
...
csrc/quantization/machete/generate.py
View file @
4068f4b5
...
...
@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
static inline std::optional<at::ScalarType> maybe_scalartype(
c10
::optional<at::Tensor> const& t) {
std
::optional<at::Tensor> const& t) {
if (!t) {
return std::nullopt;
} else {
...
...
csrc/quantization/machete/machete_mm_kernel.cuh
View file @
4068f4b5
...
...
@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
torch
::
Tensor
const
&
A
,
// MxK matrix
torch
::
Tensor
const
&
B
,
// KxN prepacked matrix
torch
::
Tensor
&
D
,
// MxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
c10
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
std
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
{
static_assert
(
!
with_group_zeropoints
||
with_group_scales
);
...
...
csrc/quantization/machete/machete_mm_launcher.cuh
View file @
4068f4b5
...
...
@@ -13,23 +13,23 @@ struct MMArgs {
torch
::
Tensor
const
&
A
;
torch
::
Tensor
const
&
B
;
vllm
::
ScalarType
const
&
b_type
;
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
c10
::
optional
<
int64_t
>
maybe_group_size
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
c10
::
optional
<
std
::
string
>
maybe_schedule
;
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
std
::
optional
<
int64_t
>
maybe_group_size
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
std
::
optional
<
std
::
string
>
maybe_schedule
;
};
struct
SupportedSchedulesArgs
{
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
};
torch
::
Tensor
mm_dispatch
(
MMArgs
args
);
...
...
csrc/quantization/machete/machete_prepack_launcher.cuh
View file @
4068f4b5
...
...
@@ -10,7 +10,7 @@ struct PrepackBArgs {
torch
::
Tensor
const
&
B
;
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
};
template
<
typename
PrepackedLayoutB
>
...
...
csrc/quantization/machete/machete_pytorch.cu
View file @
4068f4b5
...
...
@@ -10,11 +10,11 @@ using namespace vllm;
std
::
vector
<
std
::
string
>
supported_schedules
(
at
::
ScalarType
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
supported_schedules_dispatch
({
.
a_type
=
a_type
,
...
...
@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
torch
::
Tensor
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
c10
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
c10
::
optional
<
std
::
string
>
maybe_schedule
)
{
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
std
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
std
::
optional
<
std
::
string
>
maybe_schedule
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
mm_dispatch
({.
A
=
A
,
.
B
=
B
,
...
...
@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
at
::
ScalarType
const
&
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
prepack_B_dispatch
(
{.
B
=
B
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment