Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
272 additions
and
272 deletions
+272
-272
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-15
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-3
csrc/ops.h
csrc/ops.h
+27
-88
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+14
-3
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+28
-14
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+51
-25
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+4
-2
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+6
-0
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+24
-37
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+7
-1
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+27
-41
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+8
-2
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+13
-10
csrc/quantization/machete/machete_prepack_kernel.cuh
csrc/quantization/machete/machete_prepack_kernel.cuh
+3
-4
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+2
-2
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+13
-5
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+5
-0
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
+5
-0
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+5
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+28
-20
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
2c7f740a
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
csrc/moe/torch_bindings.cpp
View file @
6d2051cc
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
...
...
@@ -13,12 +12,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, "
"
b_zeros, Tensor!
g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
// conditionally compiled so impl registration is in source file
#endif
}
...
...
csrc/ops.h
View file @
6d2051cc
...
...
@@ -153,63 +153,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
namespace
machete
{
std
::
vector
<
std
::
string
>
supported_schedules
(
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
);
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
};
// namespace machete
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
...
...
@@ -219,11 +164,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -238,14 +178,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -278,26 +210,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
std
::
vector
<
torch
::
Tensor
>
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
seq_idx_
,
const
c10
::
optional
<
at
::
Tensor
>&
initial_states_
,
const
c10
::
optional
<
at
::
Tensor
>&
final_states_out_
,
bool
silu_activation
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
...
...
csrc/prepare_inputs/advance_step.cu
View file @
6d2051cc
...
...
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
...
...
@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
}
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
&
t
,
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
const
&
t
,
int64_t
const
size_0
,
int64_t
const
size_1
,
c10
::
ScalarType
const
type
)
{
bool
size_0_cond
=
true
;
...
...
@@ -211,7 +222,7 @@ void advance_step_flashinfer(
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %
d
\n
"
,
block_tables
.
stride
(
0
));
printf
(
" block_tables.stride(0) = %
zu
\n
"
,
block_tables
.
stride
(
0
));
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
...
...
@@ -303,4 +314,4 @@ void advance_step_flashinfer(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
}
\ No newline at end of file
}
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
6d2051cc
...
...
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
/
scale
);
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
}
}
...
...
@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
...
...
@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
float
val
=
static_cast
<
float
>
(
input
[
i
]);
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
...
...
@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
*
tmp_scale
);
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
}
}
...
...
@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int
const
token_idx
=
blockIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
val
=
static_cast
<
float
>
(
input
[
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
}
...
...
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
out
[
i
]
=
quant_val
;
}
}
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
6d2051cc
...
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined
CUDA_VERSION && CUDA_VERSION >= 12000
#if defined
ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype "
,
c
.
dtype
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#
if
defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_azp_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
#else
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_azp_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
\ No newline at end of file
csrc/quantization/fp8/common.cu
View file @
6d2051cc
...
...
@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
...
...
csrc/quantization/fp8/fp8_marlin.cu
View file @
6d2051cc
...
...
@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
View file @
6d2051cc
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
awq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace marlin
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
#include "core/registration.h"
namespace
marlin
{
...
...
@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
}
uint32_t
vals
[
8
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
...
...
@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
...
...
@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
...
...
@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
...
...
@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
}
// namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
...
...
@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return
out
;
}
#endif
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
)
{
...
...
@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"awq_marlin_repack"
,
&
awq_marlin_repack
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
Meta
,
m
)
{
m
.
impl
(
"awq_marlin_repack"
,
&
awq_marlin_repack_meta
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
6d2051cc
...
...
@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
...
...
@@ -2258,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
/
pack_factor
,
"b_zeros dim 1 = "
,
b_
scale
s
.
size
(
1
),
"b_zeros dim 1 = "
,
b_
zero
s
.
size
(
1
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
...
...
@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
6d2051cc
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
gptq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
#include "core/registration.h"
namespace
marlin
{
...
...
@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b2_vals
[
tile_ints
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
cur_int
=
cur_elem
/
pack_factor
;
...
...
@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
...
...
@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
...
...
@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
...
...
@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
...
...
@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
}
// namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
...
...
@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return
out
;
}
#endif
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
)
{
...
...
@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
Meta
,
m
)
{
m
.
impl
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack_meta
);
}
\ No newline at end of file
csrc/quantization/machete/generate.py
View file @
6d2051cc
...
...
@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template
=
create_template
(
PREPACK_TEMPLATE
)
def
create_sources
(
impl_config
:
ImplConfig
,
num_impl_files
=
2
):
def
create_sources
(
impl_config
:
ImplConfig
,
num_impl_files
=
1
):
sources
=
[]
type_name
=
generate_type_signature
(
impl_config
.
type_config
)
...
...
@@ -457,7 +457,13 @@ def generate():
)),
]
schedules
=
list
(
set
([
x
[
1
]
for
x
in
default_heuristic
]))
# Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and causes ccache miss.
schedules
=
[]
for
_
,
schedule_config
in
default_heuristic
:
if
schedule_config
not
in
schedules
:
schedules
.
append
(
schedule_config
)
impl_configs
=
[]
...
...
csrc/quantization/machete/machete_mainloop.cuh
View file @
6d2051cc
...
...
@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
tma_load_b
=
make_tma_copy_B
(
make_logical_tensor
(
ptr_B
,
make_shape
(
N
,
K
,
L
),
args
.
dB
));
int32_t
scale_k
=
(
ModeHasScales
)
?
(
K
+
args
.
group_size
-
1
)
/
args
.
group_size
:
0
;
int32_t
group_size
=
(
ModeHasScales
)
?
args
.
group_size
:
0
;
if
constexpr
(
ModeHasScales
)
{
tma_load_scale
=
make_tma_copy_scale
(
make_logical_tensor
(
args
.
ptr_S
,
make_shape
(
M
,
args
.
group_size
,
L
),
args
.
dS
));
tma_load_scale
=
make_tma_copy_scale
(
make_logical_tensor
(
args
.
ptr_S
,
make_shape
(
M
,
scale_k
,
L
),
args
.
dS
));
}
if
constexpr
(
KernelConversionMode
==
ConversionMode
::
ConvertAndScaleWithZero
)
{
tma_load_zero
=
make_tma_copy_zero
(
make_logical_tensor
(
args
.
ptr_Z
,
make_shape
(
M
,
args
.
group_size
,
L
),
args
.
dS
));
tma_load_zero
=
make_tma_copy_zero
(
make_logical_tensor
(
args
.
ptr_Z
,
make_shape
(
M
,
scale_k
,
L
),
args
.
dS
));
}
if
constexpr
(
KernelConversionMode
==
ConversionMode
::
DirectConvert
)
{
return
{
tma_load_a
,
tma_load_b
,
tma_load_scale
,
tma_load_zero
,
0
,
0
};
}
else
if
constexpr
(
ModeHasScales
)
{
auto
scale_k
=
(
K
+
args
.
group_size
-
1
)
/
args
.
group_size
;
if
constexpr
(
KernelConversionMode
==
ConversionMode
::
DirectConvert
||
KernelConversionMode
==
ConversionMode
::
ConvertAndScale
||
KernelConversionMode
==
ConversionMode
::
ConvertAndScaleWithZero
)
{
return
{
tma_load_a
,
tma_load_b
,
tma_load_scale
,
tma_load_zero
,
scale_k
,
args
.
group_size
};
tma_load_zero
,
scale_k
,
group_size
};
}
else
{
static_assert
(
cutlass
::
detail
::
dependent_false
<
KernelSchedule
>
,
"Conversion mode not handled in to_underlying_arguments."
);
...
...
csrc/quantization/machete/machete_prepack_kernel.cuh
View file @
6d2051cc
...
...
@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
}
template
<
typename
PrepackedLayoutB
,
typename
InLayout
>
static
void
prepack_B
(
cudaStream_t
stream
,
typename
PrepackedLayoutB
::
ElementB
const
*
B_in_ptr
,
InLayout
B_layout
,
typename
PrepackedLayoutB
::
ElementB
*
B_out_ptr
)
{
static
void
prepack_B_template
(
cudaStream_t
stream
,
typename
PrepackedLayoutB
::
ElementB
const
*
B_in_ptr
,
InLayout
B_layout
,
typename
PrepackedLayoutB
::
ElementB
*
B_out_ptr
)
{
using
TileShapeNKL
=
decltype
(
append
(
typename
PrepackedLayoutB
::
PPBlockShape_NK
{},
_1
{}));
auto
ilvd_NKbNbKL_to_offset
=
...
...
csrc/quantization/machete/machete_prepack_launcher.cuh
View file @
6d2051cc
...
...
@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// Allocate output
torch
::
Tensor
D
=
torch
::
empty_like
(
B
,
{},
at
::
MemoryFormat
::
Contiguous
);
prepack_B
<
PrepackedLayoutB
>
(
stream
,
B_ptr
,
layout_Bt
,
static_cast
<
ElementB
*>
(
D
.
mutable_data_ptr
()));
prepack_B
_template
<
PrepackedLayoutB
>
(
stream
,
B_ptr
,
layout_Bt
,
static_cast
<
ElementB
*>
(
D
.
mutable_data_ptr
()));
return
D
;
};
...
...
csrc/quantization/machete/machete_pytorch.cu
View file @
6d2051cc
...
...
@@ -2,6 +2,8 @@
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace
machete
{
using
namespace
vllm
;
...
...
@@ -78,14 +80,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
}
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
ScalarTypeTorchPtr
const
&
btype
)
{
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm
::
ScalarTypeTorchPtr
const
&
btype
)
{
return
scalar_type_dispatch
(
*
btype
,
[
&
](
auto
BType
)
{
return
PrepackBDispatcher
<
half_t
,
decltype
(
BType
),
half_t
>::
dispatch
(
B
);
});
#else
TORCH_CHECK
(
false
,
"Machete requires CUDA 12.0 or later"
);
#endif
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"machete_prepack_B"
,
&
prepack_B
);
m
.
impl
(
"machete_gemm"
,
&
gemm
);
}
// use CatchAll since supported_schedules has no tensor arguments
TORCH_LIBRARY_IMPL
(
TORCH_EXTENSION_NAME
,
CatchAll
,
m
)
{
m
.
impl
(
"machete_supported_schedules"
,
&
supported_schedules
);
}
};
// namespace machete
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
6d2051cc
...
...
@@ -26,6 +26,7 @@
#include <iostream>
#include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
...
...
@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm"
,
&
marlin_gemm
);
}
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
View file @
6d2051cc
...
...
@@ -30,6 +30,7 @@
#include <iostream>
#include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
...
...
@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return
d
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_qqq_gemm"
,
&
marlin_qqq_gemm
);
}
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
6d2051cc
...
...
@@ -28,6 +28,7 @@
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
...
@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
);
}
csrc/torch_bindings.cpp
View file @
6d2051cc
...
...
@@ -260,7 +260,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"
);
ops
.
impl
(
"marlin_gemm"
,
torch
::
kCUDA
,
&
marlin_gemm
);
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
...
...
@@ -268,22 +268,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_24_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_24_gemm
);
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops
.
def
(
"machete_supported_schedules"
,
&
machete
::
supported_schedules
);
ops
.
def
(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]"
);
ops
.
def
(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor"
);
ops
.
impl
(
"machete_gemm"
,
torch
::
kCUDA
,
&
machete
::
gemm
);
ops
.
def
(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor"
);
ops
.
impl
(
"machete_prepack_B"
,
torch
::
kCUDA
,
&
machete
::
prepack_B
);
// conditionally compiled so impl registration is in source file
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
ops
.
impl
(
"permute_cols"
,
torch
::
kCUDA
,
&
permute_cols
);
...
...
@@ -295,21 +297,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_gemm
);
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
ops
.
def
(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kMeta
,
&
gptq_marlin_repack_meta
);
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops
.
def
(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"
);
ops
.
impl
(
"awq_marlin_repack"
,
torch
::
kCUDA
,
&
awq_marlin_repack
);
ops
.
impl
(
"awq_marlin_repack"
,
torch
::
kMeta
,
&
awq_marlin_repack_meta
);
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"
);
...
...
@@ -330,7 +330,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor"
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops
.
def
(
...
...
@@ -338,7 +338,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor"
);
ops
.
impl
(
"marlin_qqq_gemm"
,
torch
::
kCUDA
,
&
marlin_qqq_gemm
);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
...
...
@@ -366,27 +366,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"Tensor? D_, Tensor
!
? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor!? x) -> Tensor[]"
);
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
ops
.
def
(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias,"
"Tensor? bias
_
,"
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor"
);
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
def
(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor"
);
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
#endif
...
...
Prev
1
2
3
4
5
6
7
8
9
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment