Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2129 additions
and
144 deletions
+2129
-144
benchmarks/kernels/benchmark_mixtral_moe.py
benchmarks/kernels/benchmark_mixtral_moe.py
+71
-38
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+12
-13
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+38
-38
csrc/cache.h
csrc/cache.h
+8
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+80
-0
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+46
-46
csrc/ops.h
csrc/ops.h
+30
-5
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+78
-0
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_impl.cuh
csrc/punica/bgmv/bgmv_impl.cuh
+4
-1
csrc/punica/bgmv/generator.py
csrc/punica/bgmv/generator.py
+1
-0
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+1
-1
csrc/pybind.cpp
csrc/pybind.cpp
+8
-1
csrc/quantization/fp8/fp8_cuda_kernels.cu
csrc/quantization/fp8/fp8_cuda_kernels.cu
+24
-1
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+1722
-0
No files found.
benchmarks/kernels/benchmark_mixtral_moe.py
View file @
1591c68f
import
argparse
import
json
import
os
import
sys
...
...
@@ -5,6 +6,7 @@ import sys
import
torch
import
torch.nn.functional
as
F
import
triton
from
tqdm
import
tqdm
from
vllm.model_executor.layers.fused_moe
import
(
fused_moe
,
get_config_file_name
)
...
...
@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe,
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
def
main
():
def
main
(
dtype
:
str
):
method
=
fused_moe
for
bs
in
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
]:
run_grid
(
bs
,
method
=
method
)
run_grid
(
bs
,
method
=
method
,
dtype
=
dtype
)
def
run_grid
(
bs
,
method
):
def
run_grid
(
bs
,
method
,
dtype
:
str
):
d_model
=
4096
num_total_experts
=
8
top_k
=
2
...
...
@@ -34,39 +36,29 @@ def run_grid(bs, method):
num_trials
=
1
configs
=
[]
if
bs
<=
16
:
BLOCK_SIZES_M
=
[
16
]
elif
bs
<=
32
:
BLOCK_SIZES_M
=
[
16
,
32
]
elif
bs
<=
64
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
]
elif
bs
<=
128
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
,
128
]
else
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
,
128
,
256
]
for
block_size_n
in
[
32
,
64
,
128
,
256
]:
for
block_size_m
in
BLOCK_SIZES_M
:
for
block_size_m
in
[
16
,
32
,
64
,
128
,
256
]
:
for
block_size_k
in
[
64
,
128
,
256
]:
for
group_size_m
in
[
1
,
16
,
32
,
64
]:
for
num_warps
in
[
4
,
8
]:
configs
.
append
({
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_K"
:
block_size_k
,
"GROUP_SIZE_M"
:
group_size_m
,
"num_warps"
:
num_warps
,
"num_stages"
:
4
,
})
for
num_stages
in
[
2
,
3
,
4
,
5
]:
configs
.
append
({
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_K"
:
block_size_k
,
"GROUP_SIZE_M"
:
group_size_m
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
})
best_config
=
None
best_time_us
=
1e20
for
config
in
configs
:
print
(
f
'
{
tp_size
=
}
{
bs
=
}
'
)
print
(
f
'
{
config
}
'
)
print
(
f
'
{
tp_size
=
}
{
bs
=
}
'
)
for
config
in
tqdm
(
config
s
):
# warmup
print
(
'warming up'
)
try
:
for
_
in
range
(
num_warmup_trials
):
run_timing
(
...
...
@@ -79,12 +71,12 @@ def run_grid(bs, method):
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
config
=
config
,
dtype
=
dtype
,
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
continue
# trial
print
(
'benchmarking'
)
for
_
in
range
(
num_trials
):
kernel_dur_ms
=
run_timing
(
num_calls
=
num_calls
,
...
...
@@ -96,6 +88,7 @@ def run_grid(bs, method):
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
config
=
config
,
dtype
=
dtype
,
)
kernel_dur_us
=
1000
*
kernel_dur_ms
...
...
@@ -105,16 +98,18 @@ def run_grid(bs, method):
best_config
=
config
best_time_us
=
kernel_dur_us
print
(
f
'
{
kernel_dur_us
=
:.
1
f
}
{
model_dur_ms
=
:.
1
f
}
'
f
'
{
bs
=
}
{
tp_size
=
}
{
top_k
=
}
{
num_total_experts
=
}
'
f
'
{
d_model
=
}
{
model_intermediate_size
=
}
{
num_layers
=
}
'
)
tqdm
.
write
(
f
'
{
kernel_dur_us
=
:.
1
f
}
{
model_dur_ms
=
:.
1
f
}
'
f
'
{
bs
=
}
{
tp_size
=
}
{
top_k
=
}
{
num_total_experts
=
}
'
f
'
{
d_model
=
}
{
model_intermediate_size
=
}
{
num_layers
=
}
'
)
print
(
"best_time_us"
,
best_time_us
)
print
(
"best_config"
,
best_config
)
# holds Dict[str, Dict[str, int]]
filename
=
get_config_file_name
(
num_total_experts
,
model_intermediate_size
//
tp_size
)
model_intermediate_size
//
tp_size
,
"float8"
if
dtype
==
"float8"
else
None
)
print
(
f
"writing config to file
{
filename
}
"
)
existing_content
=
{}
if
os
.
path
.
exists
(
filename
):
...
...
@@ -128,27 +123,48 @@ def run_grid(bs, method):
def
run_timing
(
num_calls
:
int
,
bs
:
int
,
d_model
:
int
,
num_total_experts
:
int
,
top_k
:
int
,
tp_size
:
int
,
model_intermediate_size
:
int
,
method
,
config
)
->
float
:
config
,
dtype
:
str
)
->
float
:
shard_intermediate_size
=
model_intermediate_size
//
tp_size
hidden_states
=
torch
.
rand
(
(
bs
,
d_model
),
device
=
"cuda:0"
,
dtype
=
torch
.
b
float16
,
dtype
=
torch
.
float16
,
)
w
s
=
torch
.
rand
(
w
1
=
torch
.
rand
(
(
num_total_experts
,
2
*
shard_intermediate_size
,
d_model
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
w2
s
=
torch
.
rand
(
w2
=
torch
.
rand
(
(
num_total_experts
,
d_model
,
shard_intermediate_size
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
dtype
==
"float8"
:
w1
=
w1
.
to
(
torch
.
float8_e4m3fn
)
w2
=
w2
.
to
(
torch
.
float8_e4m3fn
)
w1_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
gating_output
=
F
.
softmax
(
torch
.
rand
(
(
num_calls
,
bs
,
num_total_experts
),
device
=
hidden_states
.
device
,
...
...
@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for
i
in
range
(
num_calls
):
hidden_states
=
method
(
hidden_states
=
hidden_states
,
w1
=
ws
,
w2
=
w2s
,
w1
=
w1
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
gating_output
=
gating_output
[
i
],
topk
=
2
,
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8
=
dtype
==
"float8"
,
)
end_event
.
record
()
end_event
.
synchronize
()
...
...
@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
if
__name__
==
"__main__"
:
sys
.
exit
(
main
())
parser
=
argparse
.
ArgumentParser
(
prog
=
'benchmark_mixtral_moe'
,
description
=
'Benchmark and tune the fused_moe kernel'
,
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'float8'
,
'float16'
],
help
=
'Data type used for fused_moe kernel computations'
,
)
args
=
parser
.
parse_args
()
sys
.
exit
(
main
(
args
.
dtype
))
benchmarks/kernels/benchmark_paged_attention.py
View file @
1591c68f
...
...
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def
main
(
version
:
str
,
num_seqs
:
int
,
context
_len
:
int
,
seq
_len
:
int
,
num_query_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
...
...
@@ -48,12 +48,12 @@ def main(
dtype
=
torch
.
float
,
device
=
device
)
context
_lens
=
[
context
_len
for
_
in
range
(
num_seqs
)]
max_
context
_len
=
max
(
context
_lens
)
context
_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq
_lens
=
[
seq
_len
for
_
in
range
(
num_seqs
)]
max_
seq
_len
=
max
(
seq
_lens
)
seq
_lens
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
# Create the block tables.
max_num_blocks_per_seq
=
(
max_
context
_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_
seq
_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
...
...
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
if
version
==
"v2"
:
num_partitions
=
((
max_context_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
...
...
@@ -110,9 +109,9 @@ def main(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -129,9 +128,9 @@ def main(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--
context-
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--
seq_
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
...
...
@@ -199,7 +198,7 @@ if __name__ == '__main__':
main
(
version
=
args
.
version
,
num_seqs
=
args
.
batch_size
,
context
_len
=
args
.
context
_len
,
seq
_len
=
args
.
seq
_len
,
num_query_heads
=
args
.
num_query_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
head_size
=
args
.
head_size
,
...
...
csrc/attention/attention_kernels.cu
View file @
1591c68f
...
...
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
context_len
=
context
_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
context
_len
)
{
const
int
seq_len
=
seq
_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq
_len
)
{
// No work to do. Terminate the thread block.
return
;
}
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
context
_blocks
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
seq
_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
context
_blocks
);
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
seq
_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
context
_len
);
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq
_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
...
...
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
context
_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq
_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
context
_len
;
const
bool
mask
=
token_idx
>=
seq
_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
...
...
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
}
else
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
if
(
block_idx
==
num_
context
_blocks
-
1
)
{
if
(
block_idx
==
num_
seq
_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
context
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
...
...
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
...
...
@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
...
...
@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
...
...
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
scale, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr,
\
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
...
...
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
...
...
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_
context
_len
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_
context
_len
*
sizeof
(
float
);
int
padded_max_
seq
_len
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_
seq
_len
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
...
...
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
scale, \
block_tables, \
context
_lens, \
max_
context
_len, \
seq
_lens, \
max_
seq
_len, \
alibi_slopes, \
kv_scale);
...
...
@@ -746,9 +746,9 @@ void paged_attention_v1(
int
num_kv_heads
,
// [num_heads]
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
...
...
@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
scale, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
...
...
@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_partitions);
template
<
...
...
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
...
...
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
...
...
@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
scale, \
block_tables, \
context
_lens, \
max_
context
_len, \
seq
_lens, \
max_
seq
_len, \
alibi_slopes, \
kv_scale);
...
...
@@ -943,9 +943,9 @@ void paged_attention_v2(
int
num_kv_heads
,
// [num_heads]
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
...
...
csrc/cache.h
View file @
1591c68f
...
...
@@ -24,6 +24,14 @@ void reshape_and_cache(
const
std
::
string
&
kv_cache_dtype
,
const
float
kv_scale
);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
// Just for unittest
void
convert_fp8
(
torch
::
Tensor
&
src_cache
,
...
...
csrc/cache_kernels.cu
View file @
1591c68f
...
...
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
}
}
template
<
typename
scalar_t
>
__global__
void
reshape_and_cache_flash_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
k_cache
,
// [num_blocks, block_size, num_heads, head_size]
scalar_t
*
__restrict__
v_cache
,
// [num_blocks, block_size, num_heads, head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
if
(
slot_idx
<
0
)
{
return
;
}
const
int64_t
block_idx
=
slot_idx
/
block_size
;
const
int64_t
block_offset
=
slot_idx
%
block_size
;
const
int
n
=
num_heads
*
head_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
)
{
const
int64_t
src_key_idx
=
token_idx
*
key_stride
+
i
;
const
int64_t
src_value_idx
=
token_idx
*
value_stride
+
i
;
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_size
+
head_idx
*
head_size
+
head_offset
;
k_cache
[
tgt_value_idx
]
=
key
[
src_key_idx
];
v_cache
[
tgt_value_idx
]
=
value
[
src_value_idx
];
}
}
}
// namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
...
...
@@ -275,6 +310,51 @@ void reshape_and_cache(
}
}
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
value
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
k_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
v_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
kv_cache_dtype
)
{
// FIXME: only support auto datatype, does not support fp8
if
(
kv_cache_dtype
!=
"auto"
)
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
int
block_size
=
k_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
k_cache
.
stride
(
0
);
TORCH_CHECK
(
k_cache
.
stride
(
0
)
==
v_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_flash"
,
[
&
]
{
vllm
::
reshape_and_cache_flash_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
k_cache
.
data_ptr
<
scalar_t
>
(),
v_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
block_stride
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
);
});
}
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
>
...
...
csrc/cpu/attention.cpp
View file @
1591c68f
...
...
@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE
std
::
pair
<
T
,
T
>
reduceSoftmaxAlibi
(
T
*
data
,
const
int
size
,
const
int
capacity
,
const
float
alibi_slope
,
const
int
start_index
,
const
int
context
_len
)
{
data
[
0
]
+=
alibi_slope
*
(
start_index
-
context
_len
+
1
);
const
int
seq
_len
)
{
data
[
0
]
+=
alibi_slope
*
(
start_index
-
seq
_len
+
1
);
T
max
=
data
[
0
];
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
context
_len
+
1
);
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
seq
_len
+
1
);
data
[
i
]
=
qk
;
max
=
max
>=
qk
?
max
:
qk
;
}
...
...
@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
...
@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert
(
BLOCK_SIZE
==
16
);
int
max_
context
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
context
_len_padded
=
(
max_
context
_len
+
15
)
&
0xFFFFFFF0
;
TORCH_CHECK
((
max_
context
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
int
max_
seq
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
seq
_len_padded
=
(
max_
seq
_len
+
15
)
&
0xFFFFFFF0
;
TORCH_CHECK
((
max_
seq
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
const
int
parallel_work_item_num
=
omp_get_max_threads
();
size_t
logits_bytes
=
parallel_work_item_num
*
max_
context
_len_padded
*
sizeof
(
float
);
parallel_work_item_num
*
max_
seq
_len_padded
*
sizeof
(
float
);
float
*
logits
=
(
float
*
)
std
::
aligned_alloc
(
64
,
logits_bytes
);
// Cacheline alignment for each context token.
// [parallel_work_item_num, max_
context
_len_padded]
// [parallel_work_item_num, max_
seq
_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
int
context_len
=
context
_lens
[
seq_idx
];
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
;
const
int
block_num
=
(
context
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
block_num
=
(
seq
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int64_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
scalar_t
*
__restrict__
q_vec_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
const
int
last_block_token_num
=
context
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
seq
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
float
*
__restrict__
thread_block_logits
=
logits
+
omp_get_thread_num
()
*
max_
context
_len_padded
;
logits
+
omp_get_thread_num
()
*
max_
seq
_len_padded
;
// Compute logits
for
(
int
block_idx
=
0
;
block_idx
<
block_num
;
++
block_idx
)
{
...
...
@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax
if
(
alibi_slopes
)
{
reduceSoftmaxAlibi
(
thread_block_logits
,
context
_len
,
reduceSoftmaxAlibi
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
0
,
context
_len
);
seq
_len
);
}
else
{
reduceSoftmax
(
thread_block_logits
,
context
_len
,
reduceSoftmax
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
);
}
...
...
@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr,
context
_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr,
seq
_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
...
...
@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void
paged_attention_v1_impl_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
case
64
:
...
...
@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context
_lens, max_
context
_len, alibi_slopes);
seq
_lens, max_
seq
_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
...
...
@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
...
@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
...
@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for
(
int
partition_idx
=
0
;
partition_idx
<
max_num_partitions
;
++
partition_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
start_token_idx
=
partition_idx
*
PARTITION_SIZE
;
if
(
start_token_idx
>=
context
_len
)
if
(
start_token_idx
>=
seq
_len
)
continue
;
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
const
bool
no_reduce
=
(
partition_num
==
1
);
const
int
context_
token_num
=
(
std
::
min
(
context
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
const
int
token_num
=
(
std
::
min
(
seq
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
start_token_idx
);
const
int
block_num
=
(
context_
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
(
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
last_block_token_num
=
context_
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
+
start_token_idx
/
BLOCK_SIZE
;
...
...
@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std
::
pair
<
float
,
float
>
max_and_sum
;
if
(
alibi_slopes
)
{
max_and_sum
=
reduceSoftmaxAlibi
(
logits
,
context_
token_num
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
start_token_idx
,
context
_len
);
logits
,
token_num
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
start_token_idx
,
seq
_len
);
}
else
{
max_and_sum
=
reduceSoftmax
(
logits
,
context_
token_num
,
max_and_sum
=
reduceSoftmax
(
logits
,
token_num
,
block_num
*
BLOCK_SIZE
);
}
...
...
@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
continue
;
...
...
@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
group_idx
=
0
;
group_idx
<
head_group_num
;
++
group_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
continue
;
...
...
@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
...
...
@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
case
64
:
...
...
@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables,
context
_lens, block_size, \
max_
context
_len, alibi_slopes);
num_kv_heads, scale, block_tables,
seq
_lens, block_size, \
max_
seq
_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
...
...
@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
...
csrc/ops.h
View file @
1591c68f
...
...
@@ -10,9 +10,9 @@ void paged_attention_v1(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
...
...
@@ -28,9 +28,9 @@ void paged_attention_v2(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
...
...
@@ -124,6 +124,26 @@ torch::Tensor marlin_gemm(
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
#endif
void
squeezellm_gemm
(
...
...
@@ -146,7 +166,12 @@ void gptq_shuffle(
torch
::
Tensor
q_perm
,
int
bit
);
void
scaled_fp8_quant
(
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
...
...
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_config.h
View file @
1591c68f
...
...
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
...
...
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
float
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
float
,
nv_half
)
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
View file @
1591c68f
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_impl.cuh
View file @
1591c68f
...
...
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr
int
tz
=
4
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
feat_in
<
feat_out
)
{
if
constexpr
(
feat_in
<
=
feat_out
)
{
static_assert
(
feat_in
%
vec_size
==
0
);
constexpr
int
tx
=
feat_in
/
vec_size
;
...
...
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
csrc/punica/bgmv/generator.py
View file @
1591c68f
...
...
@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
"""
.
lstrip
()
# noqa: E501
for
input_dtype
in
DTYPES
:
...
...
csrc/punica/punica_ops.cc
View file @
1591c68f
...
...
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW
(
CASE
,
_
,
_
,
_
)
FOR_INST_BGMV_WIDE_NARROW
(
CASE_ONESIDE
,
_
,
_
,
_
)
#undef CASE
#undef CASE_ONESIDE
default:
return
false
;
}
return
true
;
}
...
...
csrc/pybind.cpp
View file @
1591c68f
...
...
@@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"scaled_fp8_quant"
,
&
scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
...
...
@@ -93,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
...
...
csrc/quantization/fp8/fp8_cuda_kernels.cu
View file @
1591c68f
...
...
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
}
// namespace vllm
void
scaled_fp8_quant
(
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
dim3
grid
(
num_tokens
);
dim3
block
(
1024
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
0 → 100644
View file @
1591c68f
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "gptq_marlin.cuh"
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
gptq_marlin
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{}
template
<
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks with
// a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{}
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using
FragA
=
Vec
<
half2
,
4
>
;
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
// quantization scales
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__
inline
void
mma
(
const
FragA
&
a_frag
,
const
FragB
&
frag_b
,
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__
inline
void
ldsm4
(
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__
inline
FragB
dequant_4bit
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
__device__
inline
FragB
dequant_8bit
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
half2
s
=
__half2half2
(
reinterpret_cast
<
__half
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
// Same as above, but for act_order (each K is multiplied individually)
__device__
inline
void
scale4
(
FragB
&
frag_b
,
FragS
&
frag_s_1
,
FragS
&
frag_s_2
,
FragS
&
frag_s_3
,
FragS
&
frag_s_4
,
int
i
)
{
__half2
s_val_1_2
;
s_val_1_2
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_1
)[
i
];
s_val_1_2
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_2
)[
i
];
__half2
s_val_3_4
;
s_val_3_4
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_3
)[
i
];
s_val_3_4
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_4
)[
i
];
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
// Given 2 floats multiply by 2 scales (halves)
__device__
inline
void
scale_float
(
float
*
c
,
FragS
&
s
)
{
__half
*
s_ptr
=
reinterpret_cast
<
__half
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
__half2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
__half2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
finish_row
=
size_m
;
}
int
cur_block_rows
=
finish_row
-
start_row
;
int
row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
auto
permute_row
=
[
&
](
int
row
)
{
int
iters
=
size_k
/
default_threads
;
int
rest
=
size_k
%
default_threads
;
int
offset
=
row
*
row_stride
;
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
offset
);
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
base_k
+=
default_threads
;
}
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
}
}
};
for
(
int
i
=
0
;
i
<
cur_block_rows
;
i
++
)
{
int
cur_row
=
start_row
+
i
;
if
(
cur_row
<
size_m
)
{
permute_row
(
cur_row
);
}
}
}
template
<
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks with
// a separate quantization scale
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
constexpr
int
pack_factor
=
32
/
num_bits
;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
div_ceil
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
if
constexpr
(
num_bits
==
4
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
(
b_quant
);
frag_b1
=
dequant_4bit
(
b_quant_shift
);
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
(
b_quant_0
);
frag_b1
=
dequant_8bit
(
b_quant_1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up the
// compiler and lead to slowdowns, hence we also use async-copies even though
// these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
__half2float
(
reinterpret_cast
<
__half
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
__half
*>
(
&
c
)[
j
]
=
__float2half
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
half2
res
=
__halves2half2
(
__float2half
(
c0
),
__float2half
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
half2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
else
{
if
(
last
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
}
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
else
{
if
(
last
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
if
constexpr
(
has_act_order
)
{
slice_k_start
=
tb_k
*
slice_row
;
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
}
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
// Default (max cache usage)
{
64
,
128
,
128
},
// Reduce N, reduce warps
{
128
,
64
,
128
},
// Reduce N more, but increase K
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
)
{
bool
cache_scales_chunk
=
has_act_order
&&
!
is_k_full
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_k
=
th_config
.
thread_k
;
// Get max scale groups per thread-block
int
tb_groups
;
if
(
group_size
==
-
1
)
{
tb_groups
=
1
;
}
else
if
(
group_size
==
0
)
{
tb_groups
=
div_ceil
(
tb_k
,
32
);
// Worst case is 32 group size
}
else
{
tb_groups
=
div_ceil
(
tb_k
,
group_size
);
}
if
(
cache_scales_chunk
)
{
int
load_groups
=
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
pipe_stages
;
}
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
div_ceil
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
return
false
;
}
// Verify K/N are divisible by thread K/N
if
(
prob_k
%
th_config
.
thread_k
!=
0
||
prob_n
%
th_config
.
thread_n
!=
0
)
{
return
false
;
}
// Verify min for thread K/N
if
(
th_config
.
thread_n
<
min_thread_n
||
th_config
.
thread_k
<
min_thread_k
)
{
return
false
;
}
// num_threads must be at least 128 (= 4 warps)
if
(
th_config
.
num_threads
<
128
)
{
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
for
(
auto
th_config
:
thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
printf
(
"WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
"GPU cache. This may "
"hurt performance. Consider upgrading your GPU.
\n
"
);
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
0
;
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
group_size
!=
-
1
);
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
else
{
TORCH_CHECK
(
group_size
==
0
);
group_blocks
=
0
;
}
}
else
{
if
(
group_size
==
-
1
)
{
group_blocks
=
-
1
;
}
else
{
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
}
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
int
*
locks
=
(
int
*
)
workspace
;
if
(
has_act_order
)
{
// Permute A columns
int
block_rows
=
div_ceil
(
prob_m
,
blocks
);
permute_cols_kernel
<<<
blocks
,
default_threads
,
0
,
stream
>>>
(
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
prob_m
,
prob_k
,
block_rows
);
A_ptr
=
a_tmp_ptr
;
}
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by having
// a full K, we have full original groups)
if
(
is_k_full
)
{
has_act_order
=
false
;
}
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
// Define kernel configurations
if
(
false
)
{
}
CALL_IF
(
4
,
32
,
2
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
}
}
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
", size_m = "
,
size_m
);
TORCH_CHECK
(
a
.
size
(
1
)
==
size_k
,
"Shape mismatch: a.size(1) = "
,
a
.
size
(
1
),
", size_k = "
,
size_k
);
// Verify B
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
gptq_marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
TORCH_CHECK
(
perm
.
device
().
is_cuda
(),
"perm is not on GPU"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Verify g_idx and perm
TORCH_CHECK
((
g_idx
.
size
(
0
)
==
0
&&
perm
.
size
(
0
)
==
0
)
||
(
g_idx
.
size
(
0
)
==
size_k
&&
perm
.
size
(
0
)
==
size_k
),
"Unexpected g_idx.size(0) = "
,
g_idx
.
size
(
0
),
" and perm.size(0) = "
,
perm
.
size
(
0
),
", where size_k = "
,
size_k
);
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
int
b_rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_rank
==
2
,
"b_scales rank = "
,
b_rank
,
" is not 2"
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales dim 1 = "
,
b_scales
.
size
(
1
),
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
0
);
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by num_groups = "
,
num_groups
);
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
0
;
}
}
else
{
if
(
num_groups
>
1
)
{
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by b_scales.size(0) = "
,
b_scales
.
size
(
0
));
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
-
1
;
}
}
// Verify workspace size
TORCH_CHECK
(
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
gptq_marlin
::
marlin_mm_f16i4
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
return
c
;
}
#endif
Prev
1
2
3
4
5
6
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment