Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47f0954a
Unverified
Commit
47f0954a
authored
Jul 03, 2024
by
Michael Goin
Committed by
GitHub
Jul 03, 2024
Browse files
[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
parent
7cd2ebb0
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1587 additions
and
44 deletions
+1587
-44
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+1308
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
docs/source/quantization/fp8.rst
docs/source/quantization/fp8.rst
+2
-1
docs/source/quantization/supported_hardware.rst
docs/source/quantization/supported_hardware.rst
+1
-1
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+84
-4
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+14
-5
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+134
-30
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+25
-3
No files found.
CMakeLists.txt
View file @
47f0954a
...
@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
...
...
csrc/ops.h
View file @
47f0954a
...
@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int64_t
num_bits
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/fp8/fp8_marlin.cu
0 → 100644
View file @
47f0954a
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "../gptq_marlin/gptq_marlin.cuh"
#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
using
namespace
gptq_marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
fp8_marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{}
}
// namespace fp8_marlin
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
typename
scalar_t
>
__device__
inline
void
mma
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
typename
scalar_t
>
__device__
inline
void
ldsm4
(
typename
ScalarType
<
scalar_t
>::
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit
<
half
>
(
int
q
)
{
// Constants for FP8 (E4M3) and FP16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
FP16_EXPONENT
=
5
;
constexpr
int
RIGHT_SHIFT
=
FP16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
FP16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
const
half2
bias_reg
=
__float2half2_rn
(
float
(
1
<<
BIAS_OFFSET
));
// Convert to half2 and apply bias
typename
ScalarType
<
half
>::
FragB
frag_b
;
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
),
bias_reg
);
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_8bit
<
nv_bfloat16
>
(
int
q
)
{
// Constants for FP8 (E4M3) and BF16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
BF16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr
uint32_t
BIAS
=
(
BIAS_OFFSET
+
127
)
<<
23
;
const
nv_bfloat162
bias_reg
=
__float2bfloat162_rn
(
*
reinterpret_cast
<
const
float
*>
(
&
BIAS
));
// Convert to bfloat162 and apply bias
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
),
bias_reg
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
__device__
inline
void
scale
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
// Given 2 floats multiply by 2 scales (halves)
template
<
typename
scalar_t
>
__device__
inline
void
scale_float
(
float
*
c
,
typename
ScalarType
<
scalar_t
>::
FragS
&
s
)
{
scalar_t
*
s_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
constexpr
int
pack_factor
=
32
/
num_bits
;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We scale a `half2` tile in row-major layout for column-wise quantization.
int
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
thread_block_reduce
();
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
start_pipes
();
}
}
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
locks); \
}
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
)
{
int
tb_n
=
th_config
.
thread_n
;
// Get max scale groups per thread-block
// Fixed for channelwise
int
tb_groups
=
1
;
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
pipe_stages
;
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
div_ceil
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
return
false
;
}
// Verify K/N are divisible by thread K/N
if
(
prob_k
%
th_config
.
thread_k
!=
0
||
prob_n
%
th_config
.
thread_n
!=
0
)
{
return
false
;
}
// Verify min for thread K/N
if
(
th_config
.
thread_n
<
min_thread_n
||
th_config
.
thread_k
<
min_thread_k
)
{
return
false
;
}
// num_threads must be at least 128 (= 4 warps)
if
(
th_config
.
num_threads
<
128
)
{
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
8
,
"num_bits must be 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
-
1
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
int
*
locks
=
(
int
*
)
workspace
;
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
// Define kernel configurations
if
(
false
)
{
}
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
}
}
}
// namespace fp8_marlin
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
8
,
"num_bits must be 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
", size_m = "
,
size_m
);
TORCH_CHECK
(
a
.
size
(
1
)
==
size_k
,
"Shape mismatch: a.size(1) = "
,
a
.
size
(
1
),
", size_k = "
,
size_k
);
// Verify B
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
int
b_rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_rank
==
2
,
"b_scales rank = "
,
b_rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
// Channelwise only for FP8
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
1
)
num_groups
=
b_scales
.
size
(
0
);
// Verify workspace size
TORCH_CHECK
(
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
fp8_marlin
::
marlin_mm_f16i4
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
fp8_marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
}
else
{
TORCH_CHECK
(
false
,
"fp8_marlin_gemm only supports bfloat16 and float16"
);
}
return
c
;
}
#endif
csrc/torch_bindings.cpp
View file @
47f0954a
...
@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
// quantization.
ops
.
def
(
ops
.
def
(
...
...
docs/source/quantization/fp8.rst
View file @
47f0954a
...
@@ -4,7 +4,8 @@ FP8
...
@@ -4,7 +4,8 @@ FP8
==================
==================
vLLM
supports
FP8
(
8
-
bit
floating
point
)
weight
and
activation
quantization
using
hardware
acceleration
on
GPUs
such
as
Nvidia
H100
and
AMD
MI300x
.
vLLM
supports
FP8
(
8
-
bit
floating
point
)
weight
and
activation
quantization
using
hardware
acceleration
on
GPUs
such
as
Nvidia
H100
and
AMD
MI300x
.
Currently
,
only
Hopper
and
Ada
Lovelace
GPUs
are
supported
.
Currently
,
only
Hopper
and
Ada
Lovelace
GPUs
are
officially
supported
for
W8A8
.
Ampere
GPUs
are
supported
for
W8A16
(
weight
-
only
FP8
)
utilizing
Marlin
kernels
.
Quantization
of
models
with
FP8
allows
for
a
2
x
reduction
in
model
memory
requirements
and
up
to
a
1.6
x
improvement
in
throughput
with
minimal
impact
on
accuracy
.
Quantization
of
models
with
FP8
allows
for
a
2
x
reduction
in
model
memory
requirements
and
up
to
a
1.6
x
improvement
in
throughput
with
minimal
impact
on
accuracy
.
Please
visit
the
HF
collection
of
`
quantized
FP8
checkpoints
of
popular
LLMs
ready
to
use
with
vLLM
<
https
://
huggingface
.
co
/
collections
/
neuralmagic
/
fp8
-
llms
-
for
-
vllm
-
666742
ed2b78b7ac8df13127
>`
_
.
Please
visit
the
HF
collection
of
`
quantized
FP8
checkpoints
of
popular
LLMs
ready
to
use
with
vLLM
<
https
://
huggingface
.
co
/
collections
/
neuralmagic
/
fp8
-
llms
-
for
-
vllm
-
666742
ed2b78b7ac8df13127
>`
_
.
...
...
docs/source/quantization/supported_hardware.rst
View file @
47f0954a
...
@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
...
@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌
❌
✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌
✅
✅ ✅ ❌ ❌ ❌ ❌ ❌
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
...
...
tests/kernels/test_marlin_gemm.py
View file @
47f0954a
...
@@ -8,7 +8,8 @@ import torch
...
@@ -8,7 +8,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm
)
marlin_perm
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
marlin_quantize
,
marlin_weights
)
marlin_quantize
,
marlin_weights
,
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
gptq_pack
,
quantize_weights
,
sort_weights
)
...
@@ -38,9 +39,11 @@ MNK_FACTORS = [
...
@@ -38,9 +39,11 @@ MNK_FACTORS = [
(
67
,
13
,
11
),
(
67
,
13
,
11
),
]
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
def
rand_data
(
shape
):
return
torch
.
randn
(
shape
,
dtype
=
torch
.
half
,
device
=
"cuda"
)
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
return
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
...
@@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
...
@@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print
(
"max_diff = {}"
.
format
(
max_diff
))
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
])
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_fp8_marlin_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
,
dtype
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
),
dtype
=
dtype
)
b_weight
=
rand_data
((
size_k
,
size_n
),
dtype
=
dtype
)
# WEIGHTS
fp8_weight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
b_weight
,
scale
=
None
)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
fp8_weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
"cuda"
),
size_k
=
size_k
,
size_n
=
size_n
,
num_bits
=
8
,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
weight_scale
.
repeat
(
1
,
size_n
).
to
(
a_input
.
dtype
).
to
(
"cuda"
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
size_k
,
size_n
=
size_n
,
group_size
=-
1
,
num_bits
=
8
,
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
a_input
,
b_q_weight
=
marlin_qweight
,
b_scales
=
marlin_scales
,
workspace
=
workspace
.
scratch
,
num_bits
=
num_bits
,
size_m
=
a_input
.
shape
[
0
],
size_n
=
b_weight
.
shape
[
1
],
size_k
=
a_input
.
shape
[
1
],
)
output_ref
=
torch
.
matmul
(
a_input
,
b_weight
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
tests/quantization/test_fp8.py
View file @
47f0954a
...
@@ -6,7 +6,7 @@ import pytest
...
@@ -6,7 +6,7 @@ import pytest
import
torch
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
.
_custom_ops
import
scaled_fp8_quant
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
MODELS
=
[
MODELS
=
[
...
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
...
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
>=
89
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
...
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
...
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
x
=
(
torch
.
randn
(
size
=
(
11
,
11
),
device
=
"cuda"
)
*
13
).
to
(
dtype
)
x
=
(
torch
.
randn
(
size
=
(
11
,
11
),
device
=
"cuda"
)
*
13
).
to
(
dtype
)
# Dynamic quantization
# Dynamic quantization
ref_y
,
inv_scale
=
scaled_fp8_quant
(
x
,
None
)
ref_y
,
inv_scale
=
ops
.
scaled_fp8_quant
(
x
,
None
)
ref_y
=
per_tensor_dequantize
(
ref_y
,
inv_scale
,
dtype
)
ref_y
=
per_tensor_dequantize
(
ref_y
,
inv_scale
,
dtype
)
# Reference dynamic quantizaton
# Reference dynamic quantizaton
...
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
...
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Static quantization
# Static quantization
y
,
_
=
scaled_fp8_quant
(
x
,
inv_scale
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
)
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Padding
# Padding
y
,
_
=
scaled_fp8_quant
(
x
,
inv_scale
,
batch_dim_padding
=
17
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
batch_dim_padding
=
17
)
assert
y
.
shape
[
0
]
==
17
assert
y
.
shape
[
0
]
==
17
assert
torch
.
allclose
(
assert
torch
.
allclose
(
ref_y
,
ref_y
,
...
...
vllm/_custom_ops.py
View file @
47f0954a
...
@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k
,
is_k_full
)
size_k
,
is_k_full
)
# fp8 marlin
def
fp8_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
fp8_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
# fp8
# fp8
def
scaled_fp8_quant
(
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
47f0954a
...
@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...
@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
pack_fp8_to_int32
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
...
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
8
9
return
8
0
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
def
_create_scale_param
(
def
_create_scale_param
(
self
,
self
,
scale_name
:
str
,
scale_name
:
str
,
...
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
process_after_load
=
True
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
...
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes
=
output_partition_sizes
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
**
extra_weight_attrs
)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if
self
.
use_marlin
:
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
prepare_layer_for_marlin
(
self
,
layer
:
Module
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
assert
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
layer
.
marlin_state
=
GPTQMarlinState
.
READY
device
=
layer
.
weight
.
device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
layer
.
weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
part_size_k
,
size_n
=
part_size_n
,
num_bits
=
8
,
)
layer
.
weight
=
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
group_size
=-
1
,
num_bits
=
8
,
)
layer
.
weight_scale
=
Parameter
(
marlin_scales
,
requires_grad
=
False
)
# Allocate marlin workspace
max_workspace_size
=
(
part_size_n
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
layer
.
workspace
=
workspace
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
(
not
hasattr
(
layer
,
"process_after_load"
)
if
(
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
):
or
not
layer
.
process_after_load
):
...
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
logical_widths
=
None
layer
.
input_scale
=
None
layer
.
input_scale
=
None
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
return
return
# If checkpoint is fp8, requantize the separately quantized logical
# If checkpoint is fp8, requantize the separately quantized logical
...
@@ -233,14 +309,42 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -233,14 +309,42 @@ class Fp8LinearMethod(LinearMethodBase):
raise
ValueError
(
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_marlin
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
layer
.
weight
,
b_scales
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
else
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# ops.scaled_fp8_quant supports both dynamic and static quant.
#
If dynamic, layer.input_scale is None and x_scale computed from x
.
#
If dynamic, layer.input_scale is None and x_scale computed from x
#
If static, layer.input_scale is scalar and x_scale is input_scale
.
#
If static, layer.input_scale is scalar and x_scale is input_scale
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
47f0954a
...
@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor
,
quantize_weights
,
sort_weights
)
get_pack_factor
,
quantize_weights
,
sort_weights
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
__cuda_arch
=
current_platform
.
get_device_capability
()
MARLIN_TILE
=
16
MARLIN_TILE
=
16
def
is_marlin_supported
():
def
is_marlin_supported
():
return
__cuda_arch
[
0
]
>=
8
capability
=
current_platform
.
get_device_capability
()
return
capability
[
0
]
>=
8
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
...
@@ -223,3 +222,26 @@ class MarlinWorkspace:
...
@@ -223,3 +222,26 @@ class MarlinWorkspace:
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
device
=
"cuda"
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
# Reshape to prepare for packing
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
# Convert fp8 to uint8 (byte) representation
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
# Pack 4 uint8 values into one int32
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment