Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d0feea31
Unverified
Commit
d0feea31
authored
Mar 08, 2025
by
Jinzhen Lin
Committed by
GitHub
Mar 07, 2025
Browse files
[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
58abe354
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
99 additions
and
24 deletions
+99
-24
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+46
-16
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+10
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-1
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+32
-0
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
d0feea31
...
@@ -538,6 +538,7 @@ __global__ void Marlin(
...
@@ -538,6 +538,7 @@ __global__ void Marlin(
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
...
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
...
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh_red
[
c_sh_rd
];
if
(
use_atomic_add
&&
slice_count
>
1
)
{
scalar_t2
*
C_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
C
[
c_gl_wr
]);
scalar_t2
*
sh_red_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
sh_red
[
c_sh_rd
]);
#pragma unroll
for
(
int
a
=
0
;
a
<
4
;
a
++
)
{
atomicAdd
(
&
C_half2
[
a
],
sh_red_half2
[
a
]);
}
}
else
{
C
[
c_gl_wr
]
=
sh_red
[
c_sh_rd
];
}
c_gl_wr
+=
c_gl_wr_delta
;
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
...
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
...
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
}
}
cp_async_fence
();
cp_async_fence
();
}
else
{
}
else
{
if
(
last
)
{
if
(
last
||
use_atomic_add
)
{
if
(
s_sh_wr_pred
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
}
...
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
...
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
}
}
}
else
{
}
else
{
if
(
last
)
{
if
(
last
||
use_atomic_add
)
{
cp_async_wait
<
0
>
();
cp_async_wait
<
0
>
();
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
...
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
}
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
if
(
slice_count
>
1
&&
!
use_atomic_add
)
{
//
block in a slice
// only globally reduce if there is more than one
block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
if
(
use_fp32_reduce
)
{
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
global_reduce_fp32
(
slice_idx
==
0
,
last
);
...
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
...
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
}
}
barrier_release
(
&
locks
[
slice_col
],
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
}
if
(
last
)
// only the last block in a slice actually writes the result
if
(
last
||
use_atomic_add
)
// only the last block in a slice actuallywrites the result
write_result
();
write_result
();
slice_row
=
0
;
slice_row
=
0
;
slice_col_par
++
;
slice_col_par
++
;
...
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
...
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
use_fp32_reduce); \
} \
} \
}
}
...
@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
...
@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
int
sms
,
int
max_par
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
if
(
has_zp
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
...
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
is_k_full
,
bool
has_zp
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
if
(
has_zp
)
{
if
(
has_zp
)
{
...
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Alloc buffers
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
c
;
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
if
(
use_atomic_add
)
{
c
=
torch
::
zeros
({
size_m
,
size_n
},
options
);
}
else
{
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
}
torch
::
Tensor
a_tmp
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
if
(
has_act_order
)
{
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
}
else
{
a_tmp
=
torch
::
empty
({
0
},
options
);
}
// Alloc C tmp buffer that is going to be used for the global reduce
// Alloc C tmp buffer that is going to be used for the global reduce
torch
::
Tensor
c_tmp
;
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
int
reduce_n
=
size_n
;
auto
options_fp32
=
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
if
(
use_fp32_reduce
)
{
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
}
else
{
reduce_max_m
=
0
;
reduce_max_m
=
0
;
reduce_n
=
0
;
reduce_n
=
0
;
c_tmp
=
torch
::
empty
({
0
},
options_fp32
);
}
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
// auto -1)
...
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Detect groupsize and act_order
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
int
rank
=
b_scales
.
sizes
().
size
();
int
rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
2
,
"b_scales rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
rank
==
2
,
"b_scales rank = "
,
rank
,
" is not 2"
);
...
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
,
is_zp_float
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm
<
nv_bfloat16
>
(
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
...
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
,
is_zp_float
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
csrc/torch_bindings.cpp
View file @
d0feea31
...
@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor"
,
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor"
,
{
stride_tag
});
{
stride_tag
});
// conditionally compiled so impl registration is in source file
// conditionally compiled so impl registration is in source file
...
...
tests/kernels/test_marlin_gemm.py
View file @
d0feea31
...
@@ -34,6 +34,7 @@ from vllm.scalar_type import scalar_types
...
@@ -34,6 +34,7 @@ from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS
=
[
False
,
True
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
USE_ATOMIC_ADD_OPTS
=
[
False
,
True
]
USE_FP32_REDUCE_OPTS
=
[
False
,
True
]
USE_FP32_REDUCE_OPTS
=
[
False
,
True
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_K_CHUNKS
=
[
128
]
...
@@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...
@@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"use_atomic_add"
,
USE_ATOMIC_ADD_OPTS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_gptq_marlin_gemm
(
def
test_gptq_marlin_gemm
(
k_chunk
,
k_chunk
,
...
@@ -203,6 +205,7 @@ def test_gptq_marlin_gemm(
...
@@ -203,6 +205,7 @@ def test_gptq_marlin_gemm(
mnk_factors
,
mnk_factors
,
act_order
,
act_order
,
is_k_full
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
use_fp32_reduce
,
):
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
@@ -228,12 +231,12 @@ def test_gptq_marlin_gemm(
...
@@ -228,12 +231,12 @@ def test_gptq_marlin_gemm(
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_
id
x
,
sort_indices
,
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
]
,
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
]
,
b_weight
.
shape
[
1
]
,
a_input
.
shape
[
1
],
is_k_full
,
False
,
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
,
False
),
use_atomic_add
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
a_input
,
...
@@ -249,6 +252,7 @@ def test_gptq_marlin_gemm(
...
@@ -249,6 +252,7 @@ def test_gptq_marlin_gemm(
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
is_zp_float
=
False
,
)
)
...
...
vllm/_custom_ops.py
View file @
d0feea31
...
@@ -301,6 +301,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -301,6 +301,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k
:
torch
.
SymInt
,
size_k
:
torch
.
SymInt
,
is_k_full
:
bool
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
has_zp
:
bool
=
False
,
use_atomic_add
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
...
@@ -713,12 +714,14 @@ def gptq_marlin_gemm(a: torch.Tensor,
...
@@ -713,12 +714,14 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
has_zp
:
bool
=
False
,
use_atomic_add
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
,
use_fp32_reduce
,
is_zp_float
)
has_zp
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
# fp8 marlin
# fp8 marlin
...
...
vllm/envs.py
View file @
d0feea31
...
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
...
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_SIZE
:
int
=
1
VLLM_DP_SIZE
:
int
=
1
VLLM_DP_MASTER_IP
:
str
=
""
VLLM_DP_MASTER_IP
:
str
=
""
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -630,6 +631,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -630,6 +631,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use S3 path for model loading in CI via RunAI Streamer
# Whether to use S3 path for model loading in CI via RunAI Streamer
"VLLM_CI_USE_S3"
:
"VLLM_CI_USE_S3"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
lambda
:
os
.
environ
.
get
(
"VLLM_CI_USE_S3"
,
"0"
)
==
"1"
,
# Whether to use atomicAdd reduce in gptq/awq marlin kernel.
"VLLM_MARLIN_USE_ATOMIC_ADD"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MARLIN_USE_ATOMIC_ADD"
,
"0"
)
==
"1"
,
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
d0feea31
...
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
...
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import
numpy
import
numpy
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
...
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return
output
return
output
def
should_use_atomic_add_reduce
(
m
:
int
,
n
:
int
,
k
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
bool
:
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if
not
envs
.
VLLM_MARLIN_USE_ATOMIC_ADD
or
device
.
type
!=
"cuda"
:
return
False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability
=
torch
.
cuda
.
get_device_capability
(
device
)
if
device_capability
[
0
]
<
9
and
dtype
==
torch
.
bfloat16
:
return
False
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return
max
(
m
,
64
)
*
n
<
64
*
2048
and
k
>=
2048
def
apply_gptq_marlin_linear
(
def
apply_gptq_marlin_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -307,6 +325,12 @@ def apply_gptq_marlin_linear(
...
@@ -307,6 +325,12 @@ def apply_gptq_marlin_linear(
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
use_atomic_add
=
should_use_atomic_add_reduce
(
m
=
reshaped_x
.
size
(
0
),
n
=
output_size_per_partition
,
k
=
reshaped_x
.
size
(
1
),
device
=
input
.
device
,
dtype
=
input
.
dtype
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight
,
weight_scale
,
weight_scale
,
...
@@ -320,6 +344,7 @@ def apply_gptq_marlin_linear(
...
@@ -320,6 +344,7 @@ def apply_gptq_marlin_linear(
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
is_zp_float
=
False
)
...
@@ -345,6 +370,12 @@ def apply_awq_marlin_linear(
...
@@ -345,6 +370,12 @@ def apply_awq_marlin_linear(
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
use_atomic_add
=
should_use_atomic_add_reduce
(
m
=
reshaped_x
.
size
(
0
),
n
=
output_size_per_partition
,
k
=
reshaped_x
.
size
(
1
),
device
=
input
.
device
,
dtype
=
input
.
dtype
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight
,
weight_scale
,
weight_scale
,
...
@@ -358,6 +389,7 @@ def apply_awq_marlin_linear(
...
@@ -358,6 +389,7 @@ def apply_awq_marlin_linear(
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
is_k_full
=
True
,
has_zp
=
True
,
has_zp
=
True
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
is_zp_float
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment