Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2527 additions
and
1604 deletions
+2527
-1604
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+4
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+1
-1
csrc/moe/marlin_kernels/marlin_moe_kernel.h
csrc/moe/marlin_kernels/marlin_moe_kernel.h
+1425
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
+29
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
+20
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+29
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+18
-0
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+185
-1371
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+5
-2
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+5
-3
csrc/ops.h
csrc/ops.h
+10
-9
csrc/permute_cols.cu
csrc/permute_cols.cu
+88
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+160
-13
csrc/quantization/gguf/dequantize.cuh
csrc/quantization/gguf/dequantize.cuh
+46
-9
csrc/quantization/gguf/ggml-common.h
csrc/quantization/gguf/ggml-common.h
+277
-131
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+5
-0
csrc/quantization/gguf/mmvq.cuh
csrc/quantization/gguf/mmvq.cuh
+8
-0
csrc/quantization/gguf/vecdotq.cuh
csrc/quantization/gguf/vecdotq.cuh
+83
-18
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+127
-46
csrc/quantization/machete/machete_mm_kernel.cuh
csrc/quantization/machete/machete_mm_kernel.cuh
+2
-1
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.h
View file @
539aa992
...
...
@@ -36,6 +36,10 @@ struct ConvParamsBase {
void
*
__restrict__
conv_state_ptr
;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t
*
__restrict__
conv_state_indices_ptr
;
void
*
__restrict__
seq_idx_ptr
;
// No __restrict__ since initial_states could be the same as final_states.
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
539aa992
...
...
@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
u
.
scalar_type
(),
"selective_scan_fwd"
,
[
&
]
{
selective_scan_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
std
::
vector
<
at
::
Tensor
>
result
=
{
out
,
x
.
value
()
};
std
::
vector
<
at
::
Tensor
>
result
=
{
out
};
if
(
has_z
)
{
result
.
push_back
(
out_z
);
}
return
result
;
}
...
...
csrc/moe/marlin_kernels/marlin_moe_kernel.h
0 → 100644
View file @
539aa992
This diff is collapsed.
Click to expand it.
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
0 → 100644
View file @
539aa992
#include "marlin_moe_kernel_ku4b8.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
0 → 100644
View file @
539aa992
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
0 → 100644
View file @
539aa992
#include "marlin_moe_kernel_ku8b128.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
0 → 100644
View file @
539aa992
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
csrc/moe/marlin_moe_ops.cu
View file @
539aa992
This diff is collapsed.
Click to expand it.
csrc/moe/marlin_moe_ops.h
View file @
539aa992
...
...
@@ -2,11 +2,14 @@
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
csrc/moe/torch_bindings.cpp
View file @
539aa992
...
...
@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor"
);
"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
#endif
}
...
...
csrc/ops.h
View file @
539aa992
...
...
@@ -155,6 +155,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
};
// namespace machete
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
...
...
@@ -226,10 +228,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
// torch::Tensor b_gptq_qzeros,
...
...
@@ -262,11 +266,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
...
...
@@ -281,8 +284,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
...
...
csrc/permute_cols.cu
0 → 100644
View file @
539aa992
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
static
constexpr
int
default_threads
=
256
;
static
constexpr
int
div_ceil
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
finish_row
=
size_m
;
}
int
cur_block_rows
=
std
::
max
(
finish_row
-
start_row
,
0
);
int
row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
auto
permute_row
=
[
&
](
int
row
)
{
int
iters
=
size_k
/
default_threads
;
int
rest
=
size_k
%
default_threads
;
int
offset
=
row
*
row_stride
;
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
offset
);
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
base_k
+=
default_threads
;
}
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
}
}
};
for
(
int
i
=
0
;
i
<
cur_block_rows
;
i
++
)
{
int
cur_row
=
start_row
+
i
;
if
(
cur_row
<
size_m
)
{
permute_row
(
cur_row
);
}
}
}
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
auto
dev
=
A
.
get_device
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
TORCH_CHECK
(
A
.
scalar_type
()
==
at
::
kHalf
||
A
.
scalar_type
()
==
at
::
kBFloat16
,
"Currently only 16bit types are supported"
);
TORCH_CHECK
(
A
.
is_contiguous
(),
"A must be contiguous"
);
TORCH_CHECK
(
A
.
size
(
-
1
)
%
8
==
0
,
"A columns must be a multiple of 8 (128bits)"
);
auto
A_2d
=
A
.
view
({
-
1
,
A
.
size
(
-
1
)});
torch
::
Tensor
D
=
torch
::
empty_like
(
A
);
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
int
block_rows
=
div_ceil
(
A_2d
.
size
(
0
),
sms
);
permute_cols_kernel
<<<
sms
,
default_threads
,
0
,
stream
>>>
(
reinterpret_cast
<
int4
const
*>
(
A_2d
.
const_data_ptr
()),
perm
.
const_data_ptr
<
int
>
(),
reinterpret_cast
<
int4
*>
(
D
.
mutable_data_ptr
()),
A_2d
.
size
(
0
),
A_2d
.
size
(
1
),
block_rows
);
return
D
;
}
\ No newline at end of file
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
539aa992
...
...
@@ -14,12 +14,17 @@
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
static
const
float
i8_min
=
static
const
expr
auto
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
static
const
float
i8_max
=
static
const
expr
auto
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// round
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float
dst
=
std
::
nearbyint
(
x
);
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
return
static_cast
<
int8_t
>
(
dst
);
...
...
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif
}
static
inline
__device__
int32_t
float_to_int32_rn
(
float
x
)
{
#ifdef USE_ROCM
// int32_max is not exactly representable as float.
// Therefore, we need to be careful and manually return int32_max on overflow.
// For symmetry, we also do the same for int32_min, even though it is exactly
// representable as float and the conversion should be exact.
static
constexpr
auto
i32_min
=
std
::
numeric_limits
<
int32_t
>::
min
();
static
constexpr
auto
i32_min_f
=
static_cast
<
float
>
(
i32_min
);
static
constexpr
auto
i32_max
=
std
::
numeric_limits
<
int32_t
>::
max
();
static
constexpr
auto
i32_max_f
=
static_cast
<
float
>
(
i32_max
);
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float
dst
=
std
::
nearbyint
(
x
);
// saturate on the higher end.
if
(
dst
>=
i32_max_f
)
{
return
i32_max
;
}
// saturate on the lower end.
if
(
dst
<=
i32_min_f
)
{
return
i32_min
;
}
return
static_cast
<
int32_t
>
(
dst
);
#else
// CUDA path
uint32_t
dst
;
asm
volatile
(
"cvt.rni.sat.s32.f32 %0, %1;"
:
"=r"
(
dst
)
:
"f"
(
x
));
return
reinterpret_cast
<
const
int32_t
&>
(
dst
);
#endif
}
static
inline
__device__
int8_t
int32_to_int8
(
int32_t
x
)
{
#ifdef USE_ROCM
static
constexpr
auto
i8_min
=
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
static
constexpr
auto
i8_max
=
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// saturate
int32_t
dst
=
std
::
clamp
(
x
,
i8_min
,
i8_max
);
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
uint32_t
dst
;
asm
volatile
(
"cvt.sat.s8.s32 %0, %1;"
:
"=r"
(
dst
)
:
"r"
(
x
));
return
reinterpret_cast
<
const
int8_t
&>
(
dst
);
#endif
}
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scale_type
>
...
...
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
}
}
template
<
typename
scalar_t
,
typename
scale_type
,
typename
azp_type
>
__global__
void
static_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
}
}
template
<
typename
scalar_t
,
typename
scale_type
>
__global__
void
dynamic_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
...
...
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
}
}
template
<
typename
scalar_t
,
typename
scale_type
,
typename
azp_type
>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int
const
token_idx
=
blockIdx
.
x
;
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
}
// Reduce the max and min values across the block
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
max_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
max_val
,
cub
::
Max
{},
blockDim
.
x
);
__syncthreads
();
// Make sure min doesn't mess with max shared memory
min_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
min_val
,
cub
::
Min
{},
blockDim
.
x
);
__shared__
scale_type
scale_sh
;
__shared__
azp_type
azp_sh
;
// Compute the scale and zero point and store them, only on the first thread
if
(
threadIdx
.
x
==
0
)
{
float
const
scale_val
=
(
max_val
-
min_val
)
/
255.0
f
;
// Use rounding to even (same as torch.round)
auto
const
azp_float
=
std
::
nearbyint
(
-
128.0
f
-
min_val
/
scale_val
);
auto
const
azp_val
=
static_cast
<
azp_type
>
(
azp_float
);
// Store the scale and azp into shared and global
scale
[
token_idx
]
=
scale_sh
=
scale_val
;
azp
[
token_idx
]
=
azp_sh
=
azp_val
;
}
// Wait for the scale and azp to be computed
__syncthreads
();
float
const
scale_val
=
scale_sh
;
azp_type
const
azp_val
=
azp_sh
;
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
token_idx
*
hidden_size
+
i
]
=
quant_val
;
}
}
}
// namespace vllm
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
scale
)
{
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
!
azp
||
azp
->
numel
()
==
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
if
(
!
azp
)
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
static_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scales
)
{
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scales
.
is_contiguous
());
TORCH_CHECK
(
!
azp
||
azp
->
is_contiguous
());
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
if
(
!
azp
)
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
dynamic_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
csrc/quantization/gguf/dequantize.cuh
View file @
539aa992
...
...
@@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
template
<
typename
dst_t
>
static
__global__
void
dequantize_block_iq1_s
(
const
void
*
__restrict__
vx
,
dst_t
*
__restrict__
yy
)
{
const
int
i
=
blockIdx
.
x
;
const
int
64_t
i
=
blockIdx
.
x
;
const
block_iq1_s
*
x
=
(
const
block_iq1_s
*
)
vx
;
const
int
tid
=
threadIdx
.
x
;
const
int
il
=
tid
/
8
;
// 0...3
const
int
ib
=
tid
%
8
;
// 0...7
const
int64_t
tid
=
threadIdx
.
x
;
const
int64_t
il
=
tid
/
8
;
// 0...3
const
int64_t
ib
=
tid
%
8
;
// 0...7
dst_t
*
y
=
yy
+
i
*
QK_K
+
32
*
ib
+
8
*
il
;
const
float
delta
=
x
[
i
].
qh
[
ib
]
&
0x8000
?
-
1
-
IQ1S_DELTA
:
-
1
+
IQ1S_DELTA
;
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
2
*
((
x
[
i
].
qh
[
ib
]
>>
12
)
&
7
)
+
1
);
uint32_t
grid32
[
2
];
const
int8_t
*
q
=
(
const
int8_t
*
)
grid32
;
grid32
[
0
]
=
iq1s_grid_gpu
[
x
[
i
].
qs
[
4
*
ib
+
il
]
|
(((
x
[
i
].
qh
[
ib
]
>>
3
*
il
)
&
7
)
<<
8
)];
grid32
[
1
]
=
(
grid32
[
0
]
>>
4
)
&
0x0f0f0f0f
;
grid32
[
0
]
&=
0x0f0f0f0f
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
y
[
j
]
=
__float2half
(
d
*
(
q
[
j
]
+
delta
));
}
}
template
<
typename
dst_t
>
static
__global__
void
dequantize_block_iq1_m
(
const
void
*
__restrict__
vx
,
dst_t
*
__restrict__
yy
)
{
const
int64_t
i
=
blockIdx
.
x
;
const
block_iq1_m
*
x
=
(
const
block_iq1_m
*
)
vx
;
const
int64_t
tid
=
threadIdx
.
x
;
const
int64_t
il
=
tid
/
8
;
// 0...3
const
int64_t
ib
=
tid
%
8
;
// 0...7
dst_t
*
y
=
yy
+
i
*
QK_K
+
32
*
ib
+
8
*
il
;
const
int
i8
=
4
*
ib
+
il
;
uint8_t
h
=
x
[
i
].
scales
[
i8
/
2
]
>>
4
*
(
i8
%
2
);
const
int8_t
*
grid
=
(
const
int8_t
*
)(
iq1s_grid
+
(
x
[
i
].
qs
[
i8
]
|
((
h
&
8
)
<<
5
)));
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
2
*
(
h
&
7
)
+
1
);
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
__float2half
(
d
*
grid
[
j
]);
const
uint16_t
*
sc
=
(
const
uint16_t
*
)
x
[
i
].
scales
;
iq1m_scale_t
scale
;
scale
.
u16
=
(
sc
[
0
]
>>
12
)
|
((
sc
[
1
]
>>
8
)
&
0x00f0
)
|
((
sc
[
2
]
>>
4
)
&
0x0f00
)
|
(
sc
[
3
]
&
0xf000
);
const
int64_t
ib16
=
2
*
ib
+
il
/
2
;
// sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
const
float
d
=
__half2float
(
scale
.
f16
)
*
(
2
*
((
sc
[
ib16
/
4
]
>>
3
*
(
ib16
%
4
))
&
0x7
)
+
1
);
const
float
delta
=
x
[
i
].
qh
[
2
*
ib
+
il
/
2
]
&
(
0x08
<<
4
*
(
il
%
2
))
?
-
1
-
IQ1M_DELTA
:
-
1
+
IQ1M_DELTA
;
uint32_t
grid32
[
2
];
const
int8_t
*
q
=
(
const
int8_t
*
)
grid32
;
grid32
[
0
]
=
iq1s_grid_gpu
[
x
[
i
].
qs
[
4
*
ib
+
il
]
|
(((
x
[
i
].
qh
[
2
*
ib
+
il
/
2
]
>>
4
*
(
il
%
2
))
&
7
)
<<
8
)];
grid32
[
1
]
=
(
grid32
[
0
]
>>
4
)
&
0x0f0f0f0f
;
grid32
[
0
]
&=
0x0f0f0f0f
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
y
[
j
]
=
__float2half
(
d
*
(
q
[
j
]
+
delta
));
}
}
template
<
typename
dst_t
>
...
...
@@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
dequantize_block_iq1_s
<<<
nb
,
32
,
0
,
stream
>>>
(
vx
,
y
);
}
template
<
typename
dst_t
>
static
void
dequantize_row_iq1_m_cuda
(
const
void
*
vx
,
dst_t
*
y
,
const
int
k
,
cudaStream_t
stream
)
{
const
int
nb
=
k
/
QK_K
;
dequantize_block_iq1_m
<<<
nb
,
32
,
0
,
stream
>>>
(
vx
,
y
);
}
template
<
typename
dst_t
>
static
void
dequantize_row_iq4_nl_cuda
(
const
void
*
vx
,
dst_t
*
y
,
const
int
k
,
cudaStream_t
stream
)
{
const
int
nb
=
(
k
+
QK_K
-
1
)
/
QK_K
;
...
...
@@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
return
dequantize_row_iq2_s_cuda
;
case
23
:
return
dequantize_row_iq4_xs_cuda
;
case
29
:
return
dequantize_row_iq1_m_cuda
;
default:
return
nullptr
;
}
...
...
csrc/quantization/gguf/ggml-common.h
View file @
539aa992
This diff is collapsed.
Click to expand it.
csrc/quantization/gguf/gguf_kernel.cu
View file @
539aa992
...
...
@@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
(
void
*
)
quant_X
.
data_ptr
(),
(
half
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
break
;
case
29
:
mul_mat_vec_iq1_m_q8_1_cuda
((
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
half
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
break
;
}
return
Y
;
}
...
...
csrc/quantization/gguf/mmvq.cuh
View file @
539aa992
...
...
@@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half *
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq1_m_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI1_M
,
block_iq1_m
,
1
,
vec_dot_iq1_m_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq4_nl_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
...
...
csrc/quantization/gguf/vecdotq.cuh
View file @
539aa992
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
static
__device__
__forceinline__
int
get_int_b2
(
const
void
*
x
,
const
int
&
i32
)
{
const
uint16_t
*
x16
=
(
const
uint16_t
*
)
x
;
// assume at least 2 byte alignment
int
x32
=
x16
[
2
*
i32
+
0
]
<<
0
;
x32
|=
x16
[
2
*
i32
+
1
]
<<
16
;
return
x32
;
}
static
__device__
__forceinline__
int
get_int_b4
(
const
void
*
x
,
const
int
&
i32
)
{
return
((
const
int
*
)
x
)[
i32
];
// assume at least 4 byte alignment
}
static
__device__
__forceinline__
int
get_int_from_int8
(
const
int8_t
*
x8
,
const
int
&
i32
)
{
const
uint16_t
*
x16
=
(
const
uint16_t
*
)
(
x8
+
sizeof
(
int
)
*
i32
);
// assume at least 2 byte alignment
int
x32
=
0
;
...
...
@@ -1661,24 +1674,76 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
const
block_iq1_s
*
bq1
=
(
const
block_iq1_s
*
)
vbq
;
const
int
ib32
=
iqs
;
int
sumi1
=
0
,
sumi2
=
0
,
sumi3
=
0
,
sumi4
=
0
;
const
uint8_t
h1
=
bq1
->
scales
[
2
*
ib32
+
0
];
const
uint8_t
h2
=
bq1
->
scales
[
2
*
ib32
+
1
];
const
int
*
q8
=
(
const
int
*
)
bq8_1
[
ib32
].
qs
;
const
int
*
grid1
=
(
const
int
*
)(
iq1s_grid
+
(
bq1
->
qs
[
4
*
ib32
+
0
]
|
((
h1
&
0x08
)
<<
5
)));
const
int
*
grid2
=
(
const
int
*
)(
iq1s_grid
+
(
bq1
->
qs
[
4
*
ib32
+
1
]
|
((
h1
&
0x80
)
<<
1
)));
const
int
*
grid3
=
(
const
int
*
)(
iq1s_grid
+
(
bq1
->
qs
[
4
*
ib32
+
2
]
|
((
h2
&
0x08
)
<<
5
)));
const
int
*
grid4
=
(
const
int
*
)(
iq1s_grid
+
(
bq1
->
qs
[
4
*
ib32
+
3
]
|
((
h2
&
0x80
)
<<
1
)));
for
(
int
j
=
0
;
j
<
2
;
++
j
)
{
sumi1
=
__dp4a
(
q8
[
j
+
0
],
grid1
[
j
],
sumi1
);
sumi2
=
__dp4a
(
q8
[
j
+
2
],
grid2
[
j
],
sumi2
);
sumi3
=
__dp4a
(
q8
[
j
+
4
],
grid3
[
j
],
sumi3
);
sumi4
=
__dp4a
(
q8
[
j
+
6
],
grid4
[
j
],
sumi4
);
}
const
float
d
=
__half2float
(
bq1
->
d
)
*
__low2float
(
bq8_1
[
ib32
].
ds
);
return
d
*
(
sumi1
*
(
2
*
(
h1
&
7
)
+
1
)
+
sumi2
*
(
2
*
((
h1
>>
4
)
&
7
)
+
1
)
+
sumi3
*
(
2
*
(
h2
&
7
)
+
1
)
+
sumi4
*
(
2
*
((
h2
>>
4
)
&
7
)
+
1
));
const
int
qs_packed
=
get_int_b2
(
bq1
->
qs
,
iqs
);
const
uint8_t
*
qs
=
(
const
uint8_t
*
)
&
qs_packed
;
const
int
qh
=
bq1
->
qh
[
iqs
];
int
sumi
=
0
;
#pragma unroll
for
(
int
l0
=
0
;
l0
<
8
;
l0
+=
2
)
{
const
int
grid
=
iq1s_grid_gpu
[
qs
[
l0
/
2
]
|
(((
qh
>>
3
*
(
l0
/
2
))
&
0x07
)
<<
8
)];
const
int
grid0
=
(
grid
>>
0
)
&
0x0F0F0F0F
;
const
int
grid1
=
(
grid
>>
4
)
&
0x0F0F0F0F
;
const
int
u0
=
get_int_b4
(
bq8_1
[
iqs
].
qs
,
l0
+
0
);
const
int
u1
=
get_int_b4
(
bq8_1
[
iqs
].
qs
,
l0
+
1
);
sumi
=
__dp4a
(
grid0
,
u0
,
sumi
);
sumi
=
__dp4a
(
grid1
,
u1
,
sumi
);
}
const
float
d1q
=
__half2float
(
bq1
->
d
)
*
(((
qh
>>
11
)
&
0x0E
)
+
1
);
const
float
delta
=
-
1.0
f
+
IQ1S_DELTA
-
(
qh
&
0x8000
)
*
(
2.0
f
*
IQ1S_DELTA
/
0x8000
);
const
float2
ds
=
__half22float2
(
bq8_1
[
iqs
].
ds
);
return
d1q
*
(
ds
.
x
*
sumi
+
ds
.
y
*
delta
);
#endif
}
static
__device__
__forceinline__
float
vec_dot_iq1_m_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
const
block_iq1_m
*
bq1
=
(
const
block_iq1_m
*
)
vbq
;
const
int
qs_packed
=
get_int_b4
(
bq1
->
qs
,
iqs
);
const
uint8_t
*
qs
=
(
const
uint8_t
*
)
&
qs_packed
;
int
sumi
[
2
]
=
{
0
};
float
sumf
[
2
]
=
{
0.0
f
};
#pragma unroll
for
(
int
l0
=
0
;
l0
<
8
;
l0
+=
2
)
{
const
int
qhl
=
bq1
->
qh
[
2
*
iqs
+
l0
/
4
]
>>
(
4
*
((
l0
/
2
)
%
2
));
const
int
grid
=
iq1s_grid_gpu
[
qs
[
l0
/
2
]
|
((
qhl
&
0x07
)
<<
8
)];
const
int
grid0
=
(
grid
>>
0
)
&
0x0F0F0F0F
;
const
int
grid1
=
(
grid
>>
4
)
&
0x0F0F0F0F
;
const
int
u0
=
get_int_b4
(
bq8_1
[
iqs
].
qs
,
l0
+
0
);
const
int
u1
=
get_int_b4
(
bq8_1
[
iqs
].
qs
,
l0
+
1
);
sumi
[
l0
/
4
]
=
__dp4a
(
grid0
,
u0
,
sumi
[
l0
/
4
]);
sumi
[
l0
/
4
]
=
__dp4a
(
grid1
,
u1
,
sumi
[
l0
/
4
]);
const
float
delta
=
-
1.0
f
+
IQ1M_DELTA
-
(
qhl
&
0x08
)
*
(
2.0
f
*
IQ1M_DELTA
/
0x08
);
int
sumy
=
0
;
sumy
=
__dp4a
(
u0
,
0x01010101
,
sumy
);
sumy
=
__dp4a
(
u1
,
0x01010101
,
sumy
);
sumf
[
l0
/
4
]
+=
delta
*
sumy
;
}
const
uint16_t
*
sc
=
(
const
uint16_t
*
)
bq1
->
scales
;
iq1m_scale_t
scale
;
scale
.
u16
=
(
sc
[
0
]
>>
12
)
|
((
sc
[
1
]
>>
8
)
&
0x00F0
)
|
((
sc
[
2
]
>>
4
)
&
0x0F00
)
|
(
sc
[
3
]
&
0xF000
);
const
float
d
=
__half2float
(
scale
.
f16
)
*
__low2float
(
bq8_1
[
iqs
].
ds
);
const
int
tmp
=
sc
[
iqs
/
2
]
>>
(
6
*
(
iqs
%
2
));
const
int
sc0
=
2
*
((
tmp
>>
0
)
&
0x07
)
+
1
;
const
int
sc1
=
2
*
((
tmp
>>
3
)
&
0x07
)
+
1
;
return
d
*
((
sumi
[
0
]
+
sumf
[
0
])
*
sc0
+
(
sumi
[
1
]
+
sumf
[
1
])
*
sc1
);
#endif
}
...
...
csrc/quantization/machete/generate.py
View file @
539aa992
...
...
@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
@
dataclass
@
dataclass
(
frozen
=
True
)
class
ScheduleConfig
:
tile_shape_mn
:
Tuple
[
int
,
int
]
cluster_shape_mnk
:
Tuple
[
int
,
int
,
int
]
...
...
@@ -328,56 +328,137 @@ def generate():
# about how this works
SCRIPT_DIR
=
os
.
path
.
dirname
(
__file__
)
schedules
=
[
ScheduleConfig
(
tile_shape_mn
=
tile_shape_mn
,
cluster_shape_mnk
=
cluster_shape_mnk
,
kernel_schedule
=
kernel_schedule
,
epilogue_schedule
=
epilogue_schedule
,
tile_scheduler
=
tile_scheduler
,
)
for
tile_shape_mn
,
cluster_shape_mnk
in
(
((
128
,
16
),
(
1
,
1
,
1
)),
((
128
,
32
),
(
1
,
1
,
1
)),
((
128
,
64
),
(
1
,
1
,
1
)),
((
128
,
128
),
(
1
,
1
,
1
)),
)
for
kernel_schedule
in
(
TmaMI
,
)
for
epilogue_schedule
in
(
TmaCoop
,
)
for
tile_scheduler
in
(
TileSchedulerType
.
StreamK
,
)
]
schedule_common_params
=
dict
(
kernel_schedule
=
TmaMI
,
epilogue_schedule
=
TmaCoop
,
tile_scheduler
=
TileSchedulerType
.
StreamK
,
)
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic
=
[
(
"M > 64"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
128
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
kernel_schedule
=
TmaMI
,
epilogue_schedule
=
TmaCoop
,
tile_scheduler
=
TileSchedulerType
.
StreamK
,
)),
(
"M > 32"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
64
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
kernel_schedule
=
TmaMI
,
epilogue_schedule
=
TmaCoop
,
tile_scheduler
=
TileSchedulerType
.
StreamK
,
)),
(
"M > 16"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
32
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
kernel_schedule
=
TmaMI
,
epilogue_schedule
=
TmaCoop
,
tile_scheduler
=
TileSchedulerType
.
StreamK
,
)),
(
None
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
16
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
kernel_schedule
=
TmaMI
,
epilogue_schedule
=
TmaCoop
,
tile_scheduler
=
TileSchedulerType
.
StreamK
))
#### M = 257+
(
"M > 256 && K <= 16384 && N <= 4096"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
128
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 256"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
256
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
#### M = 129-256
(
"M > 128 && K <= 4096 && N <= 4096"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
64
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 128 && K <= 8192 && N <= 8192"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
128
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 128"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
256
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
#### M = 65-128
(
"M > 64 && K <= 4069 && N <= 4069"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
32
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 64 && K <= 4069 && N <= 8192"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
64
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 64 && K >= 8192 && N >= 12288"
,
ScheduleConfig
(
tile_shape_mn
=
(
256
,
128
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 64"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
128
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
#### M = 33-64
(
"M > 32 && K <= 6144 && N <= 6144"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
16
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 32 && K >= 16384 && N >= 12288"
,
ScheduleConfig
(
tile_shape_mn
=
(
256
,
64
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 32"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
64
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
#### M = 17-32
(
"M > 16 && K <= 12288 && N <= 8192"
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
32
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
"M > 16"
,
ScheduleConfig
(
tile_shape_mn
=
(
256
,
32
),
cluster_shape_mnk
=
(
2
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
#### M = 1-16
(
"N >= 26624"
,
ScheduleConfig
(
tile_shape_mn
=
(
256
,
16
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
(
None
,
ScheduleConfig
(
tile_shape_mn
=
(
128
,
16
),
cluster_shape_mnk
=
(
1
,
1
,
1
),
**
schedule_common_params
# type: ignore
)),
]
schedules
=
list
(
set
([
x
[
1
]
for
x
in
default_heuristic
]))
impl_configs
=
[]
GPTQ_kernel_type_configs
=
list
(
...
...
csrc/quantization/machete/machete_mm_kernel.cuh
View file @
539aa992
...
...
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
int
M
=
size
<
0
>
(
layout_A
),
N
=
size
<
1
>
(
layout_D
),
K
=
size
<
1
>
(
layout_A
);
int
const
group_size
=
maybe_group_size
.
value_or
(
K
);
int
const
group_size
=
maybe_group_size
==
-
1
?
K
:
maybe_group_size
.
value_or
(
K
);
int
const
scale_k
=
(
K
+
group_size
-
1
)
/
group_size
;
TORCH_CHECK
(
size
<
0
>
(
layout_A
)
==
M
&&
size
<
1
>
(
layout_A
)
==
K
);
...
...
Prev
1
2
3
4
5
6
7
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment