Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
396d92d5
Unverified
Commit
396d92d5
authored
Jul 21, 2024
by
Alexander Matveev
Committed by
GitHub
Jul 21, 2024
Browse files
[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
parent
25e778aa
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1453 additions
and
283 deletions
+1453
-283
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+8
-4
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+15
-18
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+269
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+425
-100
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+27
-33
csrc/quantization/gptq_marlin/marlin.cuh
csrc/quantization/gptq_marlin/marlin.cuh
+13
-2
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
+5
-3
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+24
-22
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+129
-12
tests/quantization/test_configs.py
tests/quantization/test_configs.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+15
-7
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+268
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+10
-6
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+21
-12
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+173
-51
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+41
-10
No files found.
CMakeLists.txt
View file @
396d92d5
...
@@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
...
...
csrc/ops.h
View file @
396d92d5
...
@@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int64_t
size_k
);
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_k
,
bool
is_k_full
);
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
...
...
csrc/quantization/fp8/fp8_marlin.cu
View file @
396d92d5
...
@@ -19,10 +19,10 @@
...
@@ -19,10 +19,10 @@
* Adapted from https://github.com/IST-DASLab/marlin
* Adapted from https://github.com/IST-DASLab/marlin
*/
*/
#include "../gptq_marlin/
gptq_
marlin.cuh"
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/
gptq_
marlin_dtypes.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
using
namespace
gptq_
marlin
;
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
static_assert(std::is_same<scalar_t, half>::value || \
...
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = "
,
size_k
);
", size_k = "
,
size_k
);
// Verify B
// Verify B
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_
marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_
marlin
::
tile_size
);
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_
marlin
::
tile_size
==
0
,
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
", actual_size_n = "
,
actual_size_n
);
...
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
num_groups
=
b_scales
.
size
(
0
);
num_groups
=
b_scales
.
size
(
0
);
// Verify workspace size
// Verify workspace size
TORCH_CHECK
(
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
" is below min_workspace_size = "
,
min_workspace_size
);
...
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_scales
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
b_scales
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_
marlin
::
max_par
);
marlin
::
max_par
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
fp8_marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
fp8_marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_
marlin
::
max_par
);
marlin
::
max_par
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"fp8_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"fp8_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
0 → 100644
View file @
396d92d5
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {

// Empty placeholder for the AWQ repack kernel, used only in the
// `__CUDA_ARCH__ < 800` branch of this file. The template parameter list
// mirrors the real kernel in the other branch (<num_threads, num_bits>) so
// both branches describe the same template; the previous stub carried an
// extra `has_perm` parameter that the real kernel does not have.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin
// Placeholder for the `__CUDA_ARCH__ < 800` branch: always raises a
// "not implemented" error. The signature matches the real implementation in
// the other branch of this file (b_q_weight, size_k, size_n, num_bits); AWQ
// repacking takes no `perm` tensor.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "awq_marlin_repack(..) requires CUDA_ARCH >= 8.0");
  // Unreachable in practice: the check above always fails. Kept so the
  // function has a return statement on every path.
  return torch::empty({1, 1});
}
#else
namespace marlin {

// Repack kernel for AWQ-packed weights.
//
// Reads `b_q_weight_ptr` as rows of `size_n / pack_factor` uint32 words (one
// row per k index, see the global index computed in `fetch_to_shared`) and
// writes repacked tiles to `out_ptr`, one `tile_size`-word chunk per
// (k_tile, n_tile) pair at offset `(k_tile_id * n_tiles + n_tile_id) *
// tile_size`.
//
// Template parameters:
//   num_threads - not referenced inside the body.
//   num_bits    - bit width of one quantized value; the body has dedicated
//                 paths for 4 and a fallback path sized for 8.
//
// Work split: the k tiles are divided across `gridDim.x` blocks; each block
// walks its own range [start_k_tile, finish_k_tile) and, for every k tile,
// all n tiles. Tiles are staged through the dynamic shared-memory buffer `sh`
// using the cp_async helpers (declared outside this file).
//
// NOTE(review): `tile_k_size`, `tile_n_size`, `repack_stages`, `div_ceil`,
// `cp_async4`, `cp_async_fence` and `cp_async_wait` come from a header that is
// not visible here; their exact values/semantics should be confirmed there.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  // Number of quantized values stored in one 32-bit word.
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  // k tiles handled per block (presumably a rounded-up division).
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  // Blocks beyond the last k tile have nothing to do.
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  // Dynamic shared memory, addressed in int4 units; the size is chosen by the
  // launcher.
  extern __shared__ int4 sh[];

  // 32-bit words per tile row.
  constexpr int tile_n_ints = tile_n_size / pack_factor;

  // One thread copies one int4 (four 32-bit words), so a tile row needs
  // tile_n_ints / 4 threads and a stage needs that many per k row.
  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  // Issue the copy of tile (k_tile_id, n_tile_id) into pipeline slot `pipe`.
  // A fence is issued on every call, including for out-of-range n tiles where
  // nothing is copied.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    // Only the first `stage_size` threads take part in the copy.
    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  // Read the staged tile in slot `pipe` and write its repacked form to
  // `out_ptr`. Only warps 0-3 participate.
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    // Per-thread position inside the warp's 16-column slice: column
    // th_id / 4, starting row (th_id % 4) * 2.
    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    // Row offsets (relative to tc_row) of the four values gathered per column.
    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    // (maps the logical position inside a word to the slot it occupies in the
    // input packing; presumably the AWQ packing order - TODO confirm)
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    // vals[0..3]: four rows of column cur_n; vals[4..7]: the same rows taken
    // from the word 8 columns further along (8 / pack_factor words ahead).
    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    // Output words per (k_tile, n_tile) tile.
    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      // All 8 gathered values fit in one 32-bit word, stored in pack_idx order.
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      // 8 bits per value: vals[0..3] and vals[4..7] go to two separate words.
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  // Prime the pipeline for one k tile: issue the first `repack_stages - 1`
  // fetches, then wait for the first stage.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    // Steady state: in each step, issue the fetch for the tile
    // `repack_stages - 1` ahead, repack the tile in the current slot, then
    // wait for the next stage.
    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin
// One arm of the num_bits dispatch chain in awq_marlin_repack(). The expansion
// starts with `else`, so it must follow an `if (...) {}` statement. It relies
// on these names being in scope at the expansion site: num_bits, blocks,
// max_shared_mem, stream, b_q_weight_ptr, out_ptr, size_k, size_n.
// For a matching NUM_BITS it raises the kernel's dynamic shared-memory limit
// to max_shared_mem and launches the kernel with `blocks` blocks of
// marlin::repack_threads threads on `stream`.
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
// Repack an AWQ-packed int32 weight tensor into the marlin tile layout.
//
// Arguments:
//   b_q_weight - contiguous CUDA tensor of dtype kInt with shape
//                {size_k, size_n / pack_factor}, pack_factor = 32 / num_bits.
//   size_k     - must be divisible by marlin::tile_k_size.
//   size_n     - must be divisible by marlin::tile_n_size.
//   num_bits   - 4 or 8.
//
// Returns a new tensor with the same dtype/device as b_q_weight and shape
// {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}.
//
// Every violated precondition raises through TORCH_CHECK.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  // `blocks` is used as the grid size by CALL_IF. Start from 0 and verify the
  // query produced a usable value, so a failed attribute query can never lead
  // to a launch with an indeterminate grid size.
  int blocks = 0;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(blocks > 0,
              "Failed to query the multiprocessor count of device ", dev);

  // Passed by CALL_IF both as the kernel's dynamic shared-memory limit and as
  // the launch's dynamic shared-memory size.
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0,
              "Failed to query the opt-in shared memory size of device ", dev);

  // CALL_IF expands to an `else if` arm, hence the empty leading `if`.
  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}
#endif
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
396d92d5
...
@@ -19,8 +19,8 @@
...
@@ -19,8 +19,8 @@
* Adapted from https://github.com/IST-DASLab/marlin
* Adapted from https://github.com/IST-DASLab/marlin
*/
*/
#include "
gptq_
marlin.cuh"
#include "marlin.cuh"
#include "
gptq_
marlin_dtypes.cuh"
#include "marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
static_assert(std::is_same<scalar_t, half>::value || \
...
@@ -32,7 +32,7 @@ inline std::string str(T x) {
...
@@ -32,7 +32,7 @@ inline std::string str(T x) {
return
std
::
to_string
(
x
);
return
std
::
to_string
(
x
);
}
}
namespace
gptq_
marlin
{
namespace
marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
@@ -72,10 +72,11 @@ __global__ void Marlin(
...
@@ -72,10 +72,11 @@ __global__ void Marlin(
}
// namespace gptq_marlin
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_k
,
bool
is_k_full
)
{
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
return
torch
::
empty
({
1
,
1
});
...
@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
...
@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return
frag_b
;
return
frag_b
;
}
}
// Zero-point dequantizers
// Primary template of the 4-bit zero-point dequantizer. The body contains
// only the scalar-type static assertion and no return statement; the working
// code lives in the explicit specializations for `half` and `nv_bfloat16`.
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
// fp16 specialization: converts 4-bit fields of `q` into two half2 values
// without subtracting any offset. Nibble 0 (bits 0-3) of each 16-bit lane
// feeds result[0]; nibble 1 (bits 4-7) of each lane feeds result[1].
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
    int q) {
  // lop3 immediate encoding `(a & b) | c`; used so that each mask-and-merge
  // below is a LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  const int nibble0_mask = 0x000f000f;
  const int nibble1_mask = 0x00f000f0;
  // fp16 bit pattern 0x6400 (1024.0) in both lanes: OR-ing a small integer
  // into its mantissa yields 1024 + value.
  const int exponent_bits = 0x64006400;

  int biased_lo = lop3<and_or_lut>(q, nibble0_mask, exponent_bits);
  int biased_hi = lop3<and_or_lut>(q, nibble1_mask, exponent_bits);

  const int bias_1024 = 0x64006400;  // 1024.0 per lane
  const int scale_inv16 = 0x2c002c00;  // 1/16 per lane
  const int bias_neg64 = 0xd400d400;   // -64.0 per lane

  typename ScalarType<half>::FragB result;
  // Low nibble: (1024 + v) - 1024 = v.
  result[0] = __hsub2(*reinterpret_cast<half2*>(&biased_lo),
                      *reinterpret_cast<const half2*>(&bias_1024));
  // High nibble sits 4 bits up: (1024 + 16 * v) / 16 - 64 = v.
  result[1] = __hfma2(*reinterpret_cast<half2*>(&biased_hi),
                      *reinterpret_cast<const half2*>(&scale_inv16),
                      *reinterpret_cast<const half2*>(&bias_neg64));
  return result;
}
// bf16 specialization: converts 4-bit fields of `q` into two nv_bfloat162
// values without subtracting any offset. Bits 0-3 of each 16-bit lane feed
// result[0]; bits 4-7 of each lane feed result[1].
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
  // lop3 immediate encoding `(a & b) | c`; used so that each mask-and-merge
  // below is a LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  static constexpr uint32_t nibble_mask = 0x000f000f;
  // bf16 bit pattern 0x4300 (128.0) in both lanes: OR-ing a 4-bit integer
  // into its mantissa yields 128 + value.
  static constexpr uint32_t exponent_bits = 0x43004300;
  static constexpr uint32_t one_bits = 0x3F803F80;       // 1.0 per lane
  static constexpr uint32_t neg128_bits = 0xC300C300;    // -128.0 per lane

  int biased_lo = lop3<and_or_lut>(q, nibble_mask, exponent_bits);
  // Bring the next nibble of each lane down into bits 0-3.
  q >>= 4;
  int biased_hi = lop3<and_or_lut>(q, nibble_mask, exponent_bits);

  typename ScalarType<nv_bfloat16>::FragB result;
  // (128 + v) * 1.0 - 128 = v for both nibbles.
  result[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&biased_lo),
                      *reinterpret_cast<const nv_bfloat162*>(&one_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&neg128_bits));
  result[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&biased_hi),
                      *reinterpret_cast<const nv_bfloat162*>(&one_bits),
                      *reinterpret_cast<const nv_bfloat162*>(&neg128_bits));
  return result;
}
// Primary template of the 8-bit zero-point dequantizer. The body contains
// only the scalar-type static assertion and no return statement; the working
// code lives in the explicit specializations for `half` and `nv_bfloat16`.
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
// fp16 specialization: converts the bytes of `q` into two half2 values
// without subtracting any offset.
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
    int q) {
  // Byte-permute selectors and the filler byte handed to `prmt` (helper
  // defined outside this block). Presumably each selector pairs one source
  // byte with a 0x64 high byte per 16-bit lane, giving the fp16 value
  // 1024 + byte - TODO confirm against prmt's definition.
  static constexpr uint32_t select_elts_01 = 0x5250;
  static constexpr uint32_t select_elts_23 = 0x5351;
  static constexpr uint32_t fp16_high_byte_fill = 0x64646464;
  // fp16 bit pattern 0x6400 (1024.0) in both lanes.
  static constexpr uint32_t bias_1024 = 0x64006400;

  uint32_t biased_lo = prmt<fp16_high_byte_fill, select_elts_01>(q);
  uint32_t biased_hi = prmt<fp16_high_byte_fill, select_elts_23>(q);

  typename ScalarType<half>::FragB result;
  // Remove the 1024 bias from both pairs.
  result[0] = __hsub2(*reinterpret_cast<half2*>(&biased_lo),
                      *reinterpret_cast<const half2*>(&bias_1024));
  result[1] = __hsub2(*reinterpret_cast<half2*>(&biased_hi),
                      *reinterpret_cast<const half2*>(&bias_1024));
  return result;
}
// bf16 specialization: converts the four bytes of `q` into two nv_bfloat162
// values without subtracting any offset, going through fp32.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
  typename ScalarType<nv_bfloat16>::FragB result;

  // The same four slots are viewed both as floats and as raw 32-bit words.
  float as_float[4];
  uint32_t* as_bits = reinterpret_cast<uint32_t*>(as_float);

  // fp32 bit pattern of 8388608.0 (2^23). Replacing its lowest byte with a
  // source byte yields the float 2^23 + byte.
  static constexpr uint32_t fp32_base = 0x4B000000;
  // Selectors pick source byte 0, 2, 1, 3 of `q` (in that slot order) as the
  // low byte and take the upper three bytes from fp32_base.
  constexpr uint32_t byte_selectors[4] = {0x7650, 0x7652, 0x7651, 0x7653};

#pragma unroll
  for (int slot = 0; slot < 4; ++slot) {
    as_bits[slot] = __byte_perm(q, fp32_base, byte_selectors[slot]);
  }

  // Remove the 2^23 bias, leaving the byte value as a float.
#pragma unroll
  for (int slot = 0; slot < 4; ++slot) {
    as_float[slot] -= 8388608.f;
  }

  // Keep the upper 16 bits of each float (its bf16 truncation) and pack two
  // of them per output word.
  uint32_t* packed_out = reinterpret_cast<uint32_t*>(&result);
  packed_out[0] = __byte_perm(as_bits[0], as_bits[1], 0x7632);
  packed_out[1] = __byte_perm(as_bits[2], as_bits[3], 0x7632);

  return result;
}
// Multiply dequantized values by the corresponding quantization scale; used
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
// only for grouped quantization.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
...
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
}
// Subtract one zero-point from both halves of a dequantized B fragment.
//
//   frag_b  - fragment updated in place (elements 0 and 1).
//   frag_zp - packed pair of zero-points, reinterpreted here as an array of
//             scalar_t.
//   i       - index of the scalar inside frag_zp to use.
//
// The selected scalar is expanded with ScalarType<scalar_t>::num2num2
// (defined outside this block; presumably it duplicates the scalar into both
// lanes - TODO confirm) before the packed subtraction.
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
                              int i) {
  using packed_t = typename ScalarType<scalar_t>::scalar_t2;

  scalar_t* zp_scalars = reinterpret_cast<scalar_t*>(&frag_zp);
  packed_t zp_pair = ScalarType<scalar_t>::num2num2(zp_scalars[i]);

#pragma unroll
  for (int part = 0; part < 2; ++part) {
    frag_b[part] = __hsub2(frag_b[part], zp_pair);
  }
}
// Same as above, but for act_order (each K is multiplied individually)
// Same as above, but for act_order (each K is multiplied individually)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
...
@@ -404,6 +524,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
...
@@ -404,6 +524,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const
int
stages
,
// number of stages for the async global->shared
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
// with a separate quantization scale
>
>
...
@@ -413,6 +534,8 @@ __global__ void Marlin(
...
@@ -413,6 +534,8 @@ __global__ void Marlin(
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
...
@@ -437,6 +560,7 @@ __global__ void Marlin(
...
@@ -437,6 +560,7 @@ __global__ void Marlin(
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
constexpr
int
pack_factor
=
32
/
num_bits
;
constexpr
int
pack_factor
=
32
/
num_bits
;
...
@@ -566,6 +690,13 @@ __global__ void Marlin(
...
@@ -566,6 +690,13 @@ __global__ void Marlin(
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
@@ -605,6 +736,19 @@ __global__ void Marlin(
...
@@ -605,6 +736,19 @@ __global__ void Marlin(
int
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
// row-major in the latter case.
...
@@ -616,6 +760,18 @@ __global__ void Marlin(
...
@@ -616,6 +760,18 @@ __global__ void Marlin(
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
// Precompute which thread should not read memory in which iterations; this is
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
// when the batchsize is not a multiple of 16.
...
@@ -664,7 +820,8 @@ __global__ void Marlin(
...
@@ -664,7 +820,8 @@ __global__ void Marlin(
int4
*
sh_a
=
sh
;
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// Register storage for double buffer of shared memory reads.
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
FragA
frag_a
[
2
][
thread_m_blocks
];
...
@@ -672,6 +829,8 @@ __global__ void Marlin(
...
@@ -672,6 +829,8 @@ __global__ void Marlin(
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
// Zero accumulators.
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
auto
zero_accums
=
[
&
]()
{
...
@@ -777,6 +936,28 @@ __global__ void Marlin(
...
@@ -777,6 +936,28 @@ __global__ void Marlin(
}
}
}
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// Insert a fence even when we are winding down the pipeline to ensure that
...
@@ -784,6 +965,12 @@ __global__ void Marlin(
...
@@ -784,6 +965,12 @@ __global__ void Marlin(
cp_async_fence
();
cp_async_fence
();
};
};
auto
fetch_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// We only have `stages - 2` active fetches since we are double buffering
...
@@ -932,8 +1119,73 @@ __global__ void Marlin(
...
@@ -932,8 +1119,73 @@ __global__ void Marlin(
}
}
};
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
if
constexpr
(
!
has_zp
)
{
return
;
}
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
if
constexpr
(
num_bits
==
4
)
{
int
zp_quant
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_shift
=
zp_quant
>>
8
;
frag_zp_0
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant
);
frag_zp_1
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant_shift
);
}
else
{
int
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
frag_zp_0
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_0
);
frag_zp_1
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_1
);
}
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
// dequantization and matmul operations.
#pragma unroll
#pragma unroll
...
@@ -944,17 +1196,33 @@ __global__ void Marlin(
...
@@ -944,17 +1196,33 @@ __global__ void Marlin(
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
int
b_quant_shift
=
b_quant
>>
8
;
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant_shift
);
}
else
{
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
}
else
{
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_1
);
}
else
{
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
}
}
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
}
// Apply scale to frag_b0
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
...
@@ -967,6 +1235,11 @@ __global__ void Marlin(
...
@@ -967,6 +1235,11 @@ __global__ void Marlin(
}
}
}
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b1
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
...
@@ -1189,6 +1462,12 @@ __global__ void Marlin(
...
@@ -1189,6 +1462,12 @@ __global__ void Marlin(
}
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
}
if
constexpr
(
has_zp
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
}
...
@@ -1197,6 +1476,7 @@ __global__ void Marlin(
...
@@ -1197,6 +1476,7 @@ __global__ void Marlin(
init_same_group
(
0
);
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
};
...
@@ -1217,6 +1497,7 @@ __global__ void Marlin(
...
@@ -1217,6 +1497,7 @@ __global__ void Marlin(
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
slice_iters
>=
stages
);
...
@@ -1354,6 +1635,7 @@ __global__ void Marlin(
...
@@ -1354,6 +1635,7 @@ __global__ void Marlin(
}
else
{
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
start_pipes
();
start_pipes
();
...
@@ -1363,22 +1645,24 @@ __global__ void Marlin(
...
@@ -1363,22 +1645,24 @@ __global__ void Marlin(
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER &&
group_blocks == GROUP_BLOCKS &&
\
has_act_order == HAS_ACT_ORDER &&
has_zp == HAS_ZP &&
\
num_threads == NUM_THREADS) {
\
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {
\
cudaFuncSetAttribute( \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>,
\
HAS_ZP,
GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
HAS_ZP, GROUP_BLOCKS> \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
prob_k, locks); \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
prob_m, prob_n, prob_k, locks); \
}
}
typedef
struct
{
typedef
struct
{
...
@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS,
1
, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false,
false,
-1, NUM_THREADS) \
__CALL_IF(NUM_BITS,
1
, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false,
false,
2, NUM_THREADS) \
__CALL_IF(NUM_BITS,
1
, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false,
false,
4, NUM_THREADS) \
__CALL_IF(NUM_BITS,
1
, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false,
false,
8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS,
2
, N_BLOCKS, K_BLOCKS, false,
true,
-1, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS,
2
, N_BLOCKS, K_BLOCKS, false,
true,
2, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS,
2
, N_BLOCKS, K_BLOCKS, false,
true,
4, NUM_THREADS) \
__CALL_IF(NUM_BITS,
3
, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS,
2
, N_BLOCKS, K_BLOCKS, false,
true,
8, NUM_THREADS) \
\
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
num_groups
,
int
group_size
,
int
dev
,
int
thread_n
,
int
sms
,
int
max_par
)
{
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
...
@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
...
@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
}
// Define kernel configurations
if
(
false
)
{
if
(
false
)
{
}
}
CALL_IF
(
4
,
32
,
2
,
256
)
GPTQ_CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
8
,
256
)
CALL_IF
(
4
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
32
,
2
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
8
,
4
,
8
,
128
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
AWQ_CALL_IF
(
4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
8
,
4
,
8
,
128
)
else
{
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
,
num_groups
,
", group_size = "
,
group_size
,
", num_groups = "
+
str
(
num_groups
)
+
", thread_m_blocks = "
,
thread_m_blocks
,
", group_size = "
+
str
(
group_size
)
+
", thread_n_blocks = "
,
thread_n_blocks
,
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_k_blocks = "
,
thread_k_blocks
,
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", num_bits = "
,
num_bits
);
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
...
@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
...
@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
}
// namespace gptq_marlin
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_k
,
bool
is_k_full
)
{
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
// Verify num_bits
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = "
,
size_k
);
", size_k = "
,
size_k
);
// Verify B
// Verify B
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_
marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_
marlin
::
tile_size
);
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_
marlin
::
tile_size
==
0
,
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
", actual_size_n = "
,
actual_size_n
);
...
@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_zeros
.
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
is_contiguous
(),
"b_zeros is not contiguous"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
...
@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
group_size
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
int
b_
rank
=
b_scales
.
sizes
().
size
();
int
rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_
rank
==
2
,
"b_scales rank = "
,
b_
rank
,
" is not 2"
);
TORCH_CHECK
(
rank
==
2
,
"b_scales rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
0
);
num_groups
=
b_scales
.
size
(
0
);
...
@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
2
,
"b_zeros rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_zeros
.
size
(
0
)
==
num_groups
,
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
/
pack_factor
,
"b_zeros dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
// Verify workspace size
// Verify workspace size
TORCH_CHECK
(
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
gptq_
marlin
::
marlin_mm_f16i4
<
half
>
(
marlin
::
marlin_mm_f16i4
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_n
,
sms
,
gptq_
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
gptq_
marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
is_k_full
,
num_groups
,
group_size
,
dev
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
gptq_
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
396d92d5
#include "gptq_marlin.cuh"
#include "marlin.cuh"
namespace
gptq_marlin
{
static
constexpr
int
repack_stages
=
8
;
static
constexpr
int
repack_threads
=
256
;
static
constexpr
int
tile_k_size
=
tile_size
;
static
constexpr
int
tile_n_size
=
tile_k_size
*
4
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
int
size_k
,
int
size_n
)
{}
}
// namespace
gptq_
marlin
}
// namespace marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
...
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
#else
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
__global__
void
gptq_
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
int
size_k
,
int
size_n
)
{
...
@@ -259,17 +254,17 @@ __global__ void marlin_repack_kernel(
...
@@ -259,17 +254,17 @@ __global__ void marlin_repack_kernel(
}
}
}
}
}
// namespace
gptq_
marlin
}
// namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
cudaFuncSetAttribute( \
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads,
\
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads,
NUM_BITS,
\
NUM_BITS,
HAS_PERM>, \
HAS_PERM>,
\
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_
marlin::marlin_repack_kernel<
gptq_
marlin::repack_threads, NUM_BITS, \
marlin::
gptq_
marlin_repack_kernel<marlin::repack_threads, NUM_BITS,
\
HAS_PERM> \
HAS_PERM> \
<<<blocks,
gptq_
marlin::repack_threads, max_shared_mem, stream>>>( \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(
\
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
}
...
@@ -277,10 +272,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -277,10 +272,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_
marlin
::
tile_k_size
);
" is not divisible by tile_k_size = "
,
marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
%
marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_
marlin
::
tile_n_size
);
" is not divisible by tile_n_size = "
,
marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
@@ -308,9 +303,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -308,9 +303,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto
options
=
torch
::
TensorOptions
()
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
Tensor
out
=
torch
::
empty
(
torch
::
empty
({
size_k
/
gptq_marlin
::
tile_size
,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
options
);
// Detect if there is act_order
// Detect if there is act_order
...
...
csrc/quantization/gptq_marlin/
gptq_
marlin.cuh
→
csrc/quantization/gptq_marlin/marlin.cuh
View file @
396d92d5
...
@@ -9,7 +9,9 @@
...
@@ -9,7 +9,9 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <iostream>
#include <iostream>
namespace
gptq_marlin
{
namespace
marlin
{
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// than 1 warp per schedule allows some more latency hiding. At the same time,
...
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
...
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
max_par
=
16
;
// Repack params
static
constexpr
int
repack_stages
=
8
;
static
constexpr
int
repack_threads
=
256
;
static
constexpr
int
tile_k_size
=
tile_size
;
static
constexpr
int
tile_n_size
=
tile_k_size
*
4
;
// Helpers
template
<
typename
T
,
int
n
>
template
<
typename
T
,
int
n
>
struct
Vec
{
struct
Vec
{
T
elems
[
n
];
T
elems
[
n
];
...
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
...
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
#endif
#endif
}
// namespace
gptq_
marlin
}
// namespace marlin
csrc/quantization/gptq_marlin/
gptq_
marlin_dtypes.cuh
→
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
View file @
396d92d5
#ifndef _data_types_cuh
#ifndef _data_types_cuh
#define _data_types_cuh
#define _data_types_cuh
#include "
gptq_
marlin.cuh"
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_bf16.h>
namespace
gptq_
marlin
{
namespace
marlin
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
class
ScalarType
{};
class
ScalarType
{};
...
@@ -23,6 +23,7 @@ class ScalarType<half> {
...
@@ -23,6 +23,7 @@ class ScalarType<half> {
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
using
FragZP
=
Vec
<
half2
,
4
>
;
static
__device__
float
inline
num2float
(
const
half
x
)
{
static
__device__
float
inline
num2float
(
const
half
x
)
{
return
__half2float
(
x
);
return
__half2float
(
x
);
...
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
...
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
using
FragB
=
Vec
<
nv_bfloat162
,
2
>
;
using
FragB
=
Vec
<
nv_bfloat162
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
using
FragZP
=
Vec
<
nv_bfloat162
,
4
>
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
...
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
...
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
#endif
#endif
};
};
}
// namespace
gptq_
marlin
}
// namespace marlin
#endif
#endif
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
396d92d5
...
@@ -30,7 +30,7 @@ inline std::string str(T x) {
...
@@ -30,7 +30,7 @@ inline std::string str(T x) {
return
std
::
to_string
(
x
);
return
std
::
to_string
(
x
);
}
}
namespace
marlin
{
namespace
marlin
_dense
{
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
...
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
...
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
}
}
}
}
}
// namespace marlin
}
// namespace marlin
_dense
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
...
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
size_k
==
a
.
size
(
1
),
TORCH_CHECK
(
size_k
==
a
.
size
(
1
),
"Shape mismatch: a.size(1) = "
+
str
(
a
.
size
(
1
))
+
"Shape mismatch: a.size(1) = "
+
str
(
a
.
size
(
1
))
+
", size_k = "
+
str
(
size_k
));
", size_k = "
+
str
(
size_k
));
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
TORCH_CHECK
(
size_k
%
marlin
_dense
::
tile_size
==
0
,
"size_k = "
+
str
(
size_k
)
+
"size_k = "
+
str
(
size_k
)
+
" is not divisible by tile_size = "
+
" is not divisible by tile_size = "
+
str
(
marlin
::
tile_size
));
str
(
marlin
_dense
::
tile_size
));
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
marlin
_dense
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
+
"Shape mismatch: b_q_weight.size(0) = "
+
str
(
b_q_weight
.
size
(
0
))
+
", size_k = "
+
str
(
size_k
)
+
str
(
b_q_weight
.
size
(
0
))
+
", size_k = "
+
str
(
size_k
)
+
", tile_size = "
+
str
(
marlin
::
tile_size
));
", tile_size = "
+
str
(
marlin
_dense
::
tile_size
));
// Verify N
// Verify N
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales.size(1) = "
+
str
(
b_scales
.
size
(
1
))
+
"b_scales.size(1) = "
+
str
(
b_scales
.
size
(
1
))
+
", size_n = "
+
str
(
size_n
));
", size_n = "
+
str
(
size_n
));
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin_dense
::
tile_size
==
0
,
"b_q_weight.size(1) = "
+
str
(
b_q_weight
.
size
(
1
))
+
"b_q_weight.size(1) = "
+
str
(
b_q_weight
.
size
(
1
))
+
" is not divisible by tile_size = "
+
str
(
marlin
::
tile_size
));
" is not divisible by tile_size = "
+
str
(
marlin
_dense
::
tile_size
));
int
actual_size_n
=
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin_dense
::
tile_size
)
*
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
marlin
::
pack_factor_4bit
;
marlin_dense
::
pack_factor_4bit
;
TORCH_CHECK
(
TORCH_CHECK
(
size_n
==
actual_size_n
,
size_n
==
actual_size_n
,
"size_n = "
+
str
(
size_n
)
+
", actual_size_n = "
+
str
(
actual_size_n
));
"size_n = "
+
str
(
size_n
)
+
", actual_size_n = "
+
str
(
actual_size_n
));
...
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"Unexpected groupsize = "
+
str
(
groupsize
));
"Unexpected groupsize = "
+
str
(
groupsize
));
// Verify workspace size
// Verify workspace size
TORCH_CHECK
(
TORCH_CHECK
(
size_n
%
marlin_dense
::
min_thread_n
==
0
,
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
+
str
(
size_n
)
+
"size_n = "
+
str
(
size_n
)
+
", is not divisible by min_thread_n = "
+
str
(
marlin
::
min_thread_n
));
", is not divisible by min_thread_n = "
+
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
str
(
marlin_dense
::
min_thread_n
));
int
min_workspace_size
=
(
size_n
/
marlin_dense
::
min_thread_n
)
*
marlin_dense
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
+
str
(
workspace
.
numel
())
+
"workspace.numel = "
+
str
(
workspace
.
numel
())
+
" is below min_workspace_size = "
+
str
(
min_workspace_size
));
" is below min_workspace_size = "
+
str
(
min_workspace_size
));
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
marlin
::
marlin_cuda
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
marlin
_dense
::
marlin_cuda
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
b_scales
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
groupsize
,
dev
,
workspace
.
data_ptr
(),
groupsize
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
sms
,
marlin
::
max_par
);
thread_n
,
sms
,
marlin
_dense
::
max_par
);
return
c
;
return
c
;
}
}
csrc/torch_bindings.cpp
View file @
396d92d5
...
@@ -141,6 +141,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -141,6 +141,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
// awq_marlin repack from AWQ.
ops
.
def
(
"awq_marlin_repack"
,
&
awq_marlin_repack
);
ops
.
impl
(
"awq_marlin_repack"
,
torch
::
kCUDA
,
&
awq_marlin_repack
);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
ops
.
def
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
...
...
tests/kernels/test_marlin_gemm.py
View file @
396d92d5
...
@@ -12,16 +12,18 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
...
@@ -12,16 +12,18 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_
MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_
MARLIN_SUPPORTED_NUM_BITS
,
MARLIN_SUPPORTED_GROUP_SIZES
,
MARLIN_SUPPORTED_NUM_BITS
,
marlin_permute_scales
)
marlin_make_empty_g_idx
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack_fp8_to_int32
)
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
awq_pack
,
gptq_pack
,
quantize_weights
,
quantize_weights_with_zp
,
sort_weights
)
ACT_ORDER_OPTS
=
[
False
,
True
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
...
@@ -57,11 +59,11 @@ def rand_data(shape, dtype=torch.float16):
...
@@ -57,11 +59,11 @@ def rand_data(shape, dtype=torch.float16):
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_repack
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
act_order
,
def
test_
gptq_
marlin_repack
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
act_order
,
mnk_factors
):
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
@@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
...
@@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
                           mnk_factors):
    """Check the AWQ->Marlin repack GPU op against the reference packer.

    The same quantized weight is packed twice: once through the Python
    reference path (``marlin_weights``) and once by packing to AWQ layout
    and running ``ops.awq_marlin_repack``. Both results must match.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m, size_k, size_n = (m_factor, k_chunk * k_factor,
                              n_chunk * n_factor)

    print(f"MNK = {size_m} {size_n} {size_k}")

    # A group size of -1 means "one group spanning the whole K dimension".
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Random dense weight to be quantized.
    dense_weight = rand_data((size_k, size_n))

    # Quantize with zero points; only the integer weights are needed here.
    _w_ref, quantized, _scales, _zero_points = quantize_weights_with_zp(
        dense_weight, num_bits, group_size)

    # AWQ-layout packing: the input for the GPU repack op.
    awq_packed = awq_pack(quantized, num_bits, size_k, size_n)

    # Reference Marlin-layout packing.
    perm = get_weight_perm(num_bits)
    expected = marlin_weights(quantized, size_k, size_n, num_bits, perm)

    # GPU repack from AWQ layout to Marlin layout.
    actual = ops.awq_marlin_repack(awq_packed, size_k, size_n, num_bits)
    torch.cuda.synchronize()

    assert torch.allclose(expected, actual)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
def
test_marlin_gemm
(
def
test_
gptq_
marlin_gemm
(
k_chunk
,
k_chunk
,
n_chunk
,
n_chunk
,
num_bits
,
num_bits
,
...
@@ -155,6 +205,8 @@ def test_marlin_gemm(
...
@@ -155,6 +205,8 @@ def test_marlin_gemm(
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
num_bits
,
group_size
,
act_order
)
b_weight
,
num_bits
,
group_size
,
act_order
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
GPTQ_MARLIN_MAX_PARALLEL
)
...
@@ -162,6 +214,7 @@ def test_marlin_gemm(
...
@@ -162,6 +214,7 @@ def test_marlin_gemm(
a_input
,
a_input
,
marlin_q_w
,
marlin_q_w
,
marlin_s
,
marlin_s
,
marlin_zp
,
g_idx
,
g_idx
,
sort_indices
,
sort_indices
,
workspace
.
scratch
,
workspace
.
scratch
,
...
@@ -170,6 +223,7 @@ def test_marlin_gemm(
...
@@ -170,6 +223,7 @@ def test_marlin_gemm(
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
is_k_full
,
has_zp
=
False
,
)
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
@@ -188,7 +242,8 @@ def test_marlin_gemm(
...
@@ -188,7 +242,8 @@ def test_marlin_gemm(
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_marlin_24_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
):
def
test_gptq_marlin_24_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_m
=
m_factor
...
@@ -301,3 +356,65 @@ def test_fp8_marlin_gemm(
...
@@ -301,3 +356,65 @@ def test_fp8_marlin_gemm(
print
(
"max_diff = {}"
.
format
(
max_diff
))
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
assert
max_diff
<
0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare the Marlin GEMM with zero points against a dense matmul.

    Weights are quantized with ``awq_marlin_quantize``; the kernel output
    is compared with ``torch.matmul`` on the dequantized reference weight
    using ``compute_max_diff``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m, size_k, size_n = (m_factor, k_chunk * k_factor,
                              n_chunk * n_factor)

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    activations = rand_data((size_m, size_k))
    dense_weight = rand_data((size_k, size_n))

    reference_weight, packed_weight, scales, zero_points = (
        awq_marlin_quantize(dense_weight, num_bits, group_size))

    # This test has no activation reordering: pass empty index tensors.
    kernel_device = packed_weight.device
    empty_g_idx = torch.empty(0, dtype=torch.int, device=kernel_device)
    empty_sort_indices = torch.empty(0, dtype=torch.int, device=kernel_device)
    k_is_full = True
    uses_zero_points = True

    scratch_space = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

    kernel_out = ops.gptq_marlin_gemm(
        activations,
        packed_weight,
        scales,
        zero_points,
        empty_g_idx,
        empty_sort_indices,
        scratch_space.scratch,
        num_bits,
        activations.shape[0],
        dense_weight.shape[1],
        activations.shape[1],
        k_is_full,
        uses_zero_points,
    )
    expected_out = torch.matmul(activations, reference_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(kernel_out, expected_out)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
tests/quantization/test_configs.py
View file @
396d92d5
...
@@ -44,9 +44,9 @@ MODEL_ARG_EXPTYPES = [
...
@@ -44,9 +44,9 @@ MODEL_ARG_EXPTYPES = [
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"awq"
,
"ERROR"
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"awq"
,
"ERROR"
),
# AUTOAWQ
# AUTOAWQ
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
None
,
"awq"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
None
,
"awq
_marlin
"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"awq"
,
"awq"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"awq"
,
"awq"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"marlin"
,
"
ERROR
"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"marlin"
,
"
awq_marlin
"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"gptq"
,
"ERROR"
),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"
,
"gptq"
,
"ERROR"
),
]
]
...
...
vllm/_custom_ops.py
View file @
396d92d5
...
@@ -276,14 +276,22 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
...
@@ -276,14 +276,22 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits
)
num_bits
)
# gptq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Forward to the compiled ``_C.awq_marlin_repack`` custom op.

    All arguments are passed through unchanged, in the same order, and the
    op's result is returned as-is.
    """
    repack_op = torch.ops._C.awq_marlin_repack
    return repack_op(b_q_weight, size_k, size_n, num_bits)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
has_zp
:
bool
)
->
torch
.
Tensor
:
workspace
,
num_bits
,
size_m
,
size_n
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
size_k
,
is_k_full
)
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
)
# fp8 marlin
# fp8 marlin
...
...
vllm/config.py
View file @
396d92d5
...
@@ -251,7 +251,7 @@ class ModelConfig:
...
@@ -251,7 +251,7 @@ class ModelConfig:
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
if
(
self
.
quantization
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
)):
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
)):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
396d92d5
...
@@ -2,6 +2,7 @@ from typing import Dict, Type
...
@@ -2,6 +2,7 @@ from typing import Dict, Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
...
@@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
0 → 100644
View file @
396d92d5
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_awq_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_awq_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin.

    Holds the quantization parameters read from an AWQ checkpoint config
    and validates them with ``verify_awq_marlin_supported`` on construction.
    """

    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                 lm_head_quantized: bool) -> None:
        """Store the quantization parameters and validate them.

        Args:
            weight_bits: Number of bits per quantized weight.
            group_size: Quantization group size as given in the config.
            has_zp: Whether the checkpoint carries zero points.
            lm_head_quantized: Whether ``ParallelLMHead`` layers should also
                use the quantized linear method.
        """
        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.group_size = group_size
        self.has_zp = has_zp
        self.lm_head_quantized = lm_head_quantized

        verify_awq_marlin_supported(num_bits=self.weight_bits,
                                    group_size=self.group_size,
                                    has_zp=self.has_zp)

    def __repr__(self) -> str:
        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"has_zp={self.has_zp}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 80 presumably encodes CUDA compute capability 8.0;
        confirm against the callers of this method.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names searched for this method."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
        """Build a config from a quantization config dictionary.

        Reads the ``bits``, ``group_size`` and ``zero_point`` keys; the
        optional ``lm_head`` key defaults to ``False``.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        has_zp = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, has_zp, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Decide whether an AWQ checkpoint should run with awq_marlin.

        Args:
            hf_quant_cfg: Quantization config dictionary of the checkpoint.
            user_quant: Quantization method requested by the user, or None.

        Returns:
            ``"awq_marlin"`` when the checkpoint is compatible and the user
            requested nothing, ``"marlin"``, or ``"awq_marlin"``; otherwise
            ``None`` (no override).
        """
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        # "awq_marlin" must be accepted here: the message logged below tells
        # users to pass quantization=awq_marlin, so an explicit request for
        # it has to trigger the override as well.
        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "awq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info("Detected that the model can run with awq_marlin"
                        ", however you specified quantization=awq explicitly,"
                        " so forcing awq. Use quantization=awq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
        """Return the linear method for ``layer``, or None if not applicable.

        Applies to ``LinearBase`` layers, and to ``ParallelLMHead`` layers
        only when ``lm_head_quantized`` is set.
        """
        if (isinstance(layer, LinearBase) or
            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return AWQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for this method)."""
        return []

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Return whether ``quant_config`` describes a convertible AWQ model.

        Returns False when the quant method is not ``awq`` or when any of
        ``bits``, ``group_size`` or ``zero_point`` is missing; otherwise
        defers to ``check_awq_marlin_supported``.
        """
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        has_zp = quant_config.get("zero_point", None)

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or has_zp is None):
            return False

        return check_awq_marlin_supported(
            num_bits=num_bits,
            group_size=group_size,
            has_zp=has_zp,
            min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Creates the AWQ-layout parameters for a layer, converts them to the
    Marlin layout after loading, and applies the layer through
    ``apply_awq_marlin_linear``.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register ``qweight``, ``qzeros`` and ``scales`` on ``layer``.

        Also records ``input_size_per_partition``,
        ``output_size_per_partition`` and ``num_groups`` on the layer for
        use in ``process_weights_after_loading``.
        """
        # output_size is part of the interface but is not used here.
        del output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Normalize group_size: -1 means one group over the full input size.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Quantized weights, packed along the output dimension into int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # NOTE(review): when group_size was -1 and the input dimension is
        # partitioned, this integer division yields 0 groups; confirm that
        # combination is rejected upstream.
        num_groups = input_size_per_partition // group_size

        # Zero points: one row per group, packed like qweight.
        qzeros = Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Scales: one row per group, unpacked, in the activation dtype.
        scales = Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # Checkpoints are serialized in AutoAWQ format, which differs from the
    # Marlin format. This method is called after the weights are loaded and
    # repacks the weights, scales and zero points into the Marlin layout.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = ops.awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size)
        replace_tensor(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        # size_k here is the number of groups (rows of qzeros), not the
        # input size.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
        replace_tensor(layer, "qzeros", marlin_zp)

        # Not used by this method; set to the empty placeholders that are
        # passed to the kernel in apply().
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the layer to ``x`` via ``apply_awq_marlin_linear``.

        Uses the tensors prepared by ``process_weights_after_loading``.
        """
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
396d92d5
...
@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
...
@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
apply_
gptq_
marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
marlin_permute_scales
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -38,7 +38,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -38,7 +38,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
group_size
=
group_size
self
.
group_size
=
group_size
# Verify supported on platform.
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
num_bits
,
verify_
gptq_
marlin_supported
(
num_bits
=
self
.
num_bits
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
is_sym
=
True
)
...
@@ -135,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -135,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from compressed-tensors format to marlin format.
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
.
t
().
contiguous
(),
layer
.
weight_packed
.
t
().
contiguous
(),
...
@@ -155,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -155,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
apply_marlin_linear
(
return
apply_
gptq_
marlin_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
weight_packed
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
396d92d5
...
@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
apply_
gptq_
marlin_linear
,
check_
gptq_
marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -37,7 +37,7 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -37,7 +37,7 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
lm_head_quantized
=
lm_head_quantized
self
.
lm_head_quantized
=
lm_head_quantized
# Verify supported on platform.
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
weight_bits
,
verify_
gptq_
marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
self
.
is_sym
)
is_sym
=
self
.
is_sym
)
...
@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_marlin_compatible
(
hf_quant_cfg
)
can_convert
=
cls
.
is_
gptq_
marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
...
@@ -105,19 +105,24 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -105,19 +105,24 @@ class GPTQMarlinConfig(QuantizationConfig):
return
[]
return
[]
@
classmethod
@
classmethod
def
is_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
def
is_
gptq_
marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
if
quant_method
!=
"gptq"
:
return
False
# If we cannot find the info needed in the config, cannot convert.
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
or
desc_act
is
None
):
or
desc_act
is
None
):
return
False
return
False
return
check_marlin_supported
(
num_bits
=
num_bits
,
return
check_gptq_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
group_size
=
group_size
,
is_sym
=
sym
,
is_sym
=
sym
,
min_capability
=
cls
.
get_min_capability
())
min_capability
=
cls
.
get_min_capability
())
...
@@ -278,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -278,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
zp
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from autogptq format to marlin format.
# Repack weights from autogptq format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
qweight
,
...
@@ -302,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -302,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
apply_marlin_linear
(
return
apply_
gptq_
marlin_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
qweight
,
weight
=
layer
.
qweight
,
weight_scale
=
layer
.
scales
,
weight_scale
=
layer
.
scales
,
weight_zp
=
layer
.
zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
396d92d5
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
numpy
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.quant_utils
import
pack_cols
,
unpack_cols
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER
=
[
-
1
]
def
check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
# If the capability of the device is too low, cannot convert.
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
has_zp
:
bool
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
min_capability
is
not
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
device_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
if
device_capability
<
min_capability
:
return
False
return
(
False
,
"Marlin does not support device_capability = {}"
", the min_capability required is {}"
.
format
(
device_capability
,
min_capability
))
return
(
device_capability
>=
min_capability
if
num_bits
not
in
MARLIN_SUPPORTED_NUM_BITS
:
and
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
return
(
False
,
"Marlin does not support weight_bits = {}. "
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
"Only weight_bits = {} are supported."
.
format
(
and
is_sym
in
GPTQ_
MARLIN_SUPPORTED_
SYM
)
num_bits
,
MARLIN_SUPPORTED_
NUM_BITS
)
)
if
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
return
(
False
,
"Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported."
.
format
(
group_size
,
MARLIN_SUPPORTED_GROUP_SIZES
))
def
verify_marlin_supported
(
num_bits
:
int
,
group_size
:
Optional
[
int
],
if
not
has_zp
and
not
is_sym
:
is_sym
:
bool
)
->
None
:
return
(
False
,
"Marlin without zero_points must have symmetric quantization"
)
return
True
,
None
if
num_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
def
check_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
f
"Marlin does not support weight_bits =
{
num_bits
}
. "
min_capability
:
int
)
->
bool
:
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
cond
,
_
=
_check_marlin_supported
(
num_bits
,
"are supported."
)
group_size
,
if
(
group_size
is
None
is_sym
,
or
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
):
min_capability
,
raise
ValueError
(
has_zp
=
False
)
f
"Marlin does not support group_size =
{
group_size
}
. "
return
cond
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
def
check_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
raise
ValueError
(
min_capability
:
int
)
->
bool
:
f
"Marlin does not support is_sym = is_sym. "
cond
,
_
=
_check_marlin_supported
(
num_bits
,
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
group_size
,
False
,
min_capability
,
has_zp
=
has_zp
)
return
cond
def verify_gptq_marlin_supported(num_bits: int, group_size: int,
                                 is_sym: bool) -> None:
    """Raise if the GPTQ configuration is rejected by the marlin support check.

    The device capability is not checked here (``min_capability=None``) and
    the check runs with ``has_zp=False``.

    Raises:
        ValueError: With the message "GPTQ" + the reason returned by
            ``_check_marlin_supported`` when the check fails.
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                is_sym,
                                                min_capability=None,
                                                has_zp=False)
    if supported:
        return

    # A failed check is expected to come with an explanation.
    assert reason is not None
    raise ValueError("GPTQ" + reason)
def verify_awq_marlin_supported(num_bits: int, group_size: int,
                                has_zp: bool) -> None:
    """Raise if the AWQ configuration is rejected by the marlin support check.

    The check is run with ``is_sym=False`` and without a device capability
    requirement (``min_capability=None``).

    Raises:
        ValueError: With the message "AWQ" + the reason returned by
            ``_check_marlin_supported`` when the check fails.
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                False,
                                                min_capability=None,
                                                has_zp=has_zp)
    if supported:
        return

    # A failed check is expected to come with an explanation.
    assert reason is not None
    raise ValueError("AWQ" + reason)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
...
@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
...
@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return
s
return
s
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Permute, interleave and pack zero-points into the marlin layout.

    Args:
        zp: Unpacked zero-points.
        size_k: Row count forwarded to ``pack_cols``.
        size_n: Column count used for the final reshape and for packing.
        num_bits: Quantization width; only 4 and 8 are accepted.

    Returns:
        The result of ``pack_cols`` on the permuted zero-points.

    Raises:
        Exception: If ``num_bits`` is neither 4 nor 8.
    """
    # Zero-points are permuted like scales, but never with the "single"
    # permutation, because they are applied on every MMA.
    scale_perm, _ = get_scale_perms()
    permuted = zp.reshape((-1, len(scale_perm)))
    permuted = permuted[:, scale_perm]

    # Column interleave pattern expected by the dequantize code.
    if num_bits == 4:
        pattern = [0, 2, 4, 6, 1, 3, 5, 7]
    elif num_bits == 8:
        pattern = [0, 2, 1, 3]
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(pattern)

    permuted = permuted.reshape((-1, len(interleave)))
    permuted = permuted[:, interleave].ravel()
    permuted = permuted.reshape((-1, size_n)).contiguous()

    # Pack along the column dim.
    return pack_cols(permuted, num_bits, size_k, size_n)
def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    """Convert packed AWQ zero-points into the marlin zero-point layout.

    AWQ zero-points are quantized and packed on the column dim, and their
    values are permuted based on the dequantizer. This undoes both steps and
    then applies the marlin permutation and packing.

    Args:
        q_zp_packed: Packed AWQ zero-points.
        size_k: Row count passed to ``unpack_cols`` / ``marlin_zero_points``.
        size_n: Unpacked column count.
        num_bits: Quantization width; only 4 and 8 are accepted.

    Raises:
        Exception: If ``num_bits`` is neither 4 nor 8.
    """
    unpacked = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # argsort of the interleave pattern yields its inverse permutation.
    if num_bits == 4:
        awq_order = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        awq_order = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    inverse_order = numpy.argsort(awq_order)

    unpacked = unpacked.reshape((-1, len(inverse_order)))
    unpacked = unpacked[:, inverse_order].ravel()
    unpacked = unpacked.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(unpacked, size_k, size_n, num_bits)
# Newly generated tensors need to replace existing tensors that are
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
...
@@ -149,9 +232,11 @@ def replace_tensor(layer: torch.nn.Module, name: str,
...
@@ -149,9 +232,11 @@ def replace_tensor(layer: torch.nn.Module, name: str,
del
new_t
del
new_t
def
apply_marlin_linear
(
input
:
torch
.
Tensor
,
def
apply_gptq_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_zp
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
...
@@ -166,6 +251,42 @@ def apply_marlin_linear(input: torch.Tensor,
...
@@ -166,6 +251,42 @@ def apply_marlin_linear(input: torch.Tensor,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight
,
weight_scale
,
weight_scale
,
weight_zp
,
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
def
apply_awq_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_zp
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight_scale
,
weight_zp
,
g_idx
,
g_idx
,
g_idx_sort_indices
,
g_idx_sort_indices
,
workspace
,
workspace
,
...
@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
...
@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
size_m
=
reshaped_x
.
shape
[
0
],
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
)
is_k_full
=
True
,
has_zp
=
True
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
396d92d5
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
from
typing
import
List
from
typing
import
List
import
numpy
import
numpy
as
np
import
torch
import
torch
from
.marlin_utils
import
GPTQ_MARLIN_TILE
,
marlin_permute_scales
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
from
.quant_utils
import
get_pack_factor
,
quantize_weights
,
sort_weights
marlin_zero_points
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
quantize_weights_with_zp
,
sort_weights
)
class
MarlinWorkspace
:
class
MarlinWorkspace
:
...
@@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
...
@@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
pack_factor
=
get_pack_factor
(
num_bits
)
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
n
umpy
.
uint32
)
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
n
p
.
uint32
)
q_packed
=
n
umpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
q_packed
=
n
p
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
n
umpy
.
uint32
)
dtype
=
n
p
.
uint32
)
for
i
in
range
(
pack_factor
):
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
n
umpy
.
int32
)).
to
(
orig_device
)
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
n
p
.
int32
)).
to
(
orig_device
)
return
q_packed
return
q_packed
...
@@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int):
...
@@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int):
for
j
in
range
(
4
):
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
n
umpy
.
array
(
perm_list
)
perm
=
n
p
.
array
(
perm_list
)
if
num_bits
==
4
:
if
num_bits
==
4
:
interleave
=
n
umpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
interleave
=
n
p
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
elif
num_bits
==
8
:
interleave
=
n
umpy
.
array
([
0
,
2
,
1
,
3
])
interleave
=
n
p
.
array
([
0
,
2
,
1
,
3
])
else
:
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
...
@@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
return
res_list
def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` with zero-points and reformat the results for marlin.

    Args:
        w: Weight tensor whose shape unpacks to ``(size_k, size_n)``.
        num_bits: Quantization width.
        group_size: Group size along k; ``-1`` means one group spanning
            all of ``size_k``.

    Returns:
        A list ``[w_ref, marlin_q_w, marlin_s, marlin_zp]`` with every tensor
        moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # A group size of -1 stands for a single group covering the whole k dim.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # The k dim must split evenly into groups.
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zero-points.
    w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)

    # Reformat each piece to marlin; zero-points have one row per group.
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                get_weight_perm(num_bits))
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)

    return [
        tensor.to(w.device)
        for tensor in (w_ref, marlin_q_w, marlin_s, marlin_zp)
    ]
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment