OpenDAS / Lmdeploy · Commits · d26f4c73

Commit d26f4c73, authored May 27, 2024 by gaoqiong
增加awq模块 (add the AWQ module)
Parent: 2326380c

Changes: 32 · Showing 20 changed files with 599 additions and 289 deletions (+599 -289)
Changed files:

  CMakeLists.txt                                           +1   -1
  lmdeploy/cli/cli.py                                      +5   -1
  lmdeploy/turbomind/deploy/converter.py                   +5   -0
  lmdeploy/turbomind/deploy/target_model/base.py           +12  -0
  lmdeploy/turbomind/deploy/target_model/w4.py             +85  -7
  lmdeploy/turbomind/turbomind.py                          +4   -0
  setup.py                                                 +12  -0
  src/turbomind/kernels/CMakeLists.txt                     +1   -1
  src/turbomind/kernels/gemm_s_f16/CMakeLists.txt          +3   -1
  src/turbomind/kernels/gemm_s_f16/common.h                +37  -39
  src/turbomind/kernels/gemm_s_f16/format.cu               +105 -1
  src/turbomind/kernels/gemm_s_f16/format.h                +11  -0
  src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h           +4   -0
  src/turbomind/kernels/gemm_s_f16/gemm_template.h         +215 -214
  src/turbomind/models/llama/BlockManager.cc               +1   -0
  src/turbomind/models/llama/CMakeLists.txt                +5   -3
  src/turbomind/models/llama/LlamaDecoderLayerWeight.cc    +85  -18
  src/turbomind/models/llama/LlamaDecoderLayerWeight.h     +1   -0
  src/turbomind/models/llama/LlamaDenseWeight.h            +1   -0
  src/turbomind/models/llama/LlamaFfnLayer.cc              +6   -3
CMakeLists.txt

@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
   # $<TARGET_OBJECTS:flash_attention2>
   $<TARGET_OBJECTS:Llama>
   $<TARGET_OBJECTS:LlamaTritonBackend>
-  # $<TARGET_OBJECTS:gemm_s4_f16>
+  $<TARGET_OBJECTS:gemm_s4_f16>
   $<TARGET_OBJECTS:TopKSamplingLayer>
   $<TARGET_OBJECTS:TopPSamplingLayer>
   $<TARGET_OBJECTS:TransformerTritonBackend>
lmdeploy/cli/cli.py

@@ -61,7 +61,11 @@ class CLI(object):
                             default=0,
                             help='A parameter used in awq to quantize fp16 weights '
                             'to 4 bits')
+        parser.add_argument('--w4-weight-layout',
+                            type=int,
+                            default=2,
+                            help='A parameter used in AWQ to control the layout of weight ')
         parser.set_defaults(run=CLI.convert)

     @staticmethod
lmdeploy/turbomind/deploy/converter.py

@@ -196,6 +196,7 @@ def main(model_name: str,
          tp: int = 1,
          quant_path: str = None,
          group_size: int = 0,
+         w4_weight_layout: int = 2,
          **kwargs):
     """deploy llama family models via turbomind.

@@ -215,6 +216,7 @@ def main(model_name: str,
         quant_path (str): Path of the quantized model, which can be None.
         group_size (int): a parameter used in AWQ to quantize fp16 weights
             to 4 bits
+        w4_weight_layout (int): a parameter used in AWQ to control the layout of weight
         kwargs (dict): other params for convert
     """

@@ -260,10 +262,13 @@ def main(model_name: str,
     cfg.tensor_para_size = tp
     cfg.rotary_embedding = cfg.size_per_head
     cfg.group_size = group_size
+    cfg.w4_weight_layout = w4_weight_layout
     if inferred_model_format.find('awq') != -1:
         cfg.weight_type = 'int4'
         output_format = 'w4'
         assert group_size > 0, f'group_size: {group_size} should > 0'
+        print("w4_weight_layout:", w4_weight_layout)
+        assert w4_weight_layout >= 0 and w4_weight_layout < 3, f'w4_weight_layout: {w4_weight_layout} should >= 0 and < 3'
     else:
         #output_format = update_output_format(model_name, inferred_model_format,
         #                                     model_path, output_format)
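For context, the accepted values of the new flag are 0, 1 and 2: layouts 1 and 2 route the AWQ export through the new convert_s4_/permute_trans path added in w4.py below, while any other value keeps the original convert_s4 path. A minimal sketch of that dispatch, not part of the commit (the helper name select_w4_path is hypothetical):

    def select_w4_path(w4_weight_layout: int) -> str:
        # Same bounds check as converter.py in this commit.
        assert 0 <= w4_weight_layout < 3, \
            f'w4_weight_layout: {w4_weight_layout} should >= 0 and < 3'
        if w4_weight_layout in (1, 2):
            return 'convert_s4_ + permute_trans (new layout)'
        return 'convert_s4 (original layout)'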
lmdeploy/turbomind/deploy/target_model/base.py

@@ -5,6 +5,7 @@ import inspect
 import io
 import json
 import os.path as osp
+import os
 from abc import ABC, abstractmethod
 from configparser import ConfigParser

@@ -52,6 +53,7 @@ class TurbomindModelConfig:
     rope_theta: float = 10000.0
     size_per_head: int = 128
     group_size: int = 0
+    w4_weight_layout: int = 2
     max_batch_size: int = 64
     max_context_token_num: int = 1
     step_length: int = 1

@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
         self.to_file = to_file
         self.out_dir = out_dir
         self.tm_params = {}
+        #self.weight_layout= 1
+        #获取环境变量  ("read the environment variable")
+        #env_weight_layout = os.environ.get('LMDEPLOY_WEIGHTLAYOUT_SWITCH', '1')
+        #self.weight_layout =int(env_weight_layout)
+        #print("self.weight_layout:",self.weight_layout)

     @abstractmethod
     def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:

@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
         return x.view(n_heads, 2, dim // n_heads // 2,
                       1).transpose(1, 2).reshape(dim, 1)

+def permute_trans(x: torch.Tensor):
+    if x.shape[-1] > 1:
+        dim = x.shape[-1]
+        return x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim, -1)

 def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
               dim: int):
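The new permute_trans helper simply transposes the trailing 2-D view, so a scales/zeros tensor of shape [input_dims // group_size, output_dims] becomes [output_dims, input_dims // group_size]. A small, self-contained usage sketch (shapes are illustrative only, not taken from the commit):

    import torch

    # Copy of permute_trans from the diff, applied to a toy scales/zeros tensor.
    def permute_trans(x: torch.Tensor):
        if x.shape[-1] > 1:
            dim = x.shape[-1]
            return x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim, -1)

    sz = torch.arange(12, dtype=torch.float16).reshape(3, 4)  # [k // group_size = 3, m = 4]
    out = permute_trans(sz)
    print(out.shape)  # torch.Size([4, 3]) -- the per-group scales now run along dim 1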
lmdeploy/turbomind/deploy/target_model/w4.py

@@ -8,7 +8,7 @@ import lmdeploy
 from ..source_model.base import BaseInputModel, BaseReader
 from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
-                   merge_qkv, permute)
+                   merge_qkv, permute, permute_trans)

 # import _turbomind as _tm
 # TODO: find another way import _turbomind

@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
                          qw.size(-1) * 8, qw.size(0), group_size)
     return _qw, _sz

+def convert_s4_(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
+                group_size: int):
+    assert qw.is_contiguous()
+    assert qz.is_contiguous()
+    assert s.is_contiguous()
+    _qw = torch.zeros_like(qw)
+    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
+    _ws = torch.zeros_like(s)
+    _tm.convert_s4_k_m8_(_qw, _sz, _ws, qw, s, qz,
+                         qw.size(-1) * 8, qw.size(0), group_size)
+    return _qw, _sz

 def tp_m_s4(x: torch.Tensor, tp: int):
     return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,

@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
         """Export transformer layer i."""
         group_size = self.cfg.group_size
         tp = self.cfg.tensor_para_size
+        w4_weight_layout = self.cfg.w4_weight_layout
         size_per_head = self.cfg.size_per_head
         # attn
         q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i))

@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel):
         qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
         qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)

-        qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
-        qkv_qw = tp_m_s4(qkv_qw, tp)
+        pad_group_count = 2
+        if w4_weight_layout == 1 or w4_weight_layout == 2:
+            if qkv_qw.shape[0] % 4096 == 0:
+                qkv_qw_padding = torch.zeros(group_size * pad_group_count,
+                                             qkv_qw.shape[1],
+                                             dtype=torch.int32).cuda()
+                qkv_qw = torch.cat((qkv_qw, qkv_qw_padding), dim=0).contiguous()
+                qkv_qz_padding = torch.zeros(pad_group_count,
+                                             qkv_qz.shape[1],
+                                             dtype=torch.int32).cuda()
+                qkv_qz = torch.cat((qkv_qz, qkv_qz_padding), dim=0).contiguous()
+                qkv_s_padding = torch.zeros(pad_group_count,
+                                            qkv_s.shape[1],
+                                            dtype=torch.float16).cuda()
+                qkv_s = torch.cat((qkv_s, qkv_s_padding), dim=0).contiguous()
+            qkv_qw, qkv_sz = convert_s4_(qkv_qw, qkv_qz, qkv_s, group_size)
+            qkv_qw = tp_m_s4(qkv_qw, tp)
+            qkv_sz = permute_trans(qkv_sz)
+        else:
+            qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
+            qkv_qw = tp_m_s4(qkv_qw, tp)
+        #print("请设置weight layout\n")  # "please set the weight layout"
         self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1)
         self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1)

-        o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
+        if w4_weight_layout == 1 or w4_weight_layout == 2:
+            if o_qw.shape[0] % 4096 == 0:
+                o_qw_padding = torch.zeros(group_size * pad_group_count,
+                                           o_qw.shape[1],
+                                           dtype=torch.int32).cuda()
+                o_qw = torch.cat((o_qw, o_qw_padding), dim=0).contiguous()
+                o_qz_padding = torch.zeros(pad_group_count,
+                                           o_qz.shape[1],
+                                           dtype=torch.int32).cuda()
+                o_qz = torch.cat((o_qz, o_qz_padding), dim=0).contiguous()
+                o_s_padding = torch.zeros(pad_group_count,
+                                          o_s.shape[1],
+                                          dtype=torch.float16).cuda()
+                o_s = torch.cat((o_s, o_s_padding), dim=0).contiguous()
+            o_qw, o_sz = convert_s4_(o_qw, o_qz, o_s, group_size)
+            o_sz = permute_trans(o_sz)
+        else:
+            o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
         self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0)
         self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0)

@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
         w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
                                               w3_s)

-        w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
-        w13_qw = tp_m_s4(w13_qw, tp)
+        if w4_weight_layout == 1 or w4_weight_layout == 2:
+            if w13_qw.shape[0] % 4096 == 0:
+                w13_qw_padding = torch.zeros(group_size * pad_group_count,
+                                             w13_qw.shape[1],
+                                             dtype=torch.int32).cuda()
+                w13_qw = torch.cat((w13_qw, w13_qw_padding), dim=0).contiguous()
+                w13_qz_padding = torch.zeros(pad_group_count,
+                                             w13_qz.shape[1],
+                                             dtype=torch.int32).cuda()
+                w13_qz = torch.cat((w13_qz, w13_qz_padding), dim=0).contiguous()
+                w13_s_padding = torch.zeros(pad_group_count,
+                                            w13_s.shape[1],
+                                            dtype=torch.float16).cuda()
+                w13_s = torch.cat((w13_s, w13_s_padding), dim=0).contiguous()
+            w13_qw, w13_sz = convert_s4_(w13_qw, w13_qz, w13_s, group_size)
+            w13_qw = tp_m_s4(w13_qw, tp)
+            w13_sz = permute_trans(w13_sz)
+        else:
+            w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
+            w13_qw = tp_m_s4(w13_qw, tp)
         self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1)
         self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros',
                         -1)

-        w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
+        if w4_weight_layout == 1 or w4_weight_layout == 2:
+            #pading
+            if w2_qw.shape[0] % 4096 == 0:
+                w2_qw_padding = torch.zeros(group_size * pad_group_count,
+                                            w2_qw.shape[1],
+                                            dtype=torch.int32).cuda()
+                w2_qw = torch.cat((w2_qw, w2_qw_padding), dim=0).contiguous()
+                w2_qz_padding = torch.zeros(pad_group_count,
+                                            w2_qz.shape[1],
+                                            dtype=torch.int32).cuda()
+                w2_qz = torch.cat((w2_qz, w2_qz_padding), dim=0).contiguous()
+                w2_s_padding = torch.zeros(pad_group_count,
+                                           w2_s.shape[1],
+                                           dtype=torch.float16).cuda()
+                w2_s = torch.cat((w2_s, w2_s_padding), dim=0).contiguous()
+            w2_qw, w2_sz = convert_s4_(w2_qw, w2_qz, w2_s, group_size)
+            w2_sz = permute_trans(w2_sz)
+        else:
+            w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
         self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0)
         self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0)
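The padding step repeated for qkv, wo, w13 and w2 above follows one pattern: when the packed qweight row count is a multiple of 4096, two extra quantization groups of zero rows (group_size * pad_group_count rows of qweight, pad_group_count rows of qzeros and scales) are appended before convert_s4_. A standalone sketch of just that step, assuming the same tensor roles as in the commit (the helper name pad_awq_groups is illustrative; the commit does this inline and moves the tensors to CUDA, which is omitted here):

    import torch

    def pad_awq_groups(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
                       group_size: int, pad_group_count: int = 2):
        # qw: packed int4 weight rows, qz: packed zeros, s: fp16 scales.
        # Append pad_group_count all-zero quantization groups when the packed
        # row count hits a multiple of 4096, mirroring w4.py in this commit.
        if qw.shape[0] % 4096 == 0:
            qw = torch.cat((qw, torch.zeros(group_size * pad_group_count, qw.shape[1],
                                            dtype=torch.int32)), dim=0).contiguous()
            qz = torch.cat((qz, torch.zeros(pad_group_count, qz.shape[1],
                                            dtype=torch.int32)), dim=0).contiguous()
            s = torch.cat((s, torch.zeros(pad_group_count, s.shape[1],
                                          dtype=torch.float16)), dim=0).contiguous()
        return qw, qz, s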
lmdeploy/turbomind/turbomind.py

@@ -147,6 +147,7 @@ class TurboMind:
                  model_name: Optional[str] = None,
                  model_format: Optional[str] = None,
                  group_size: Optional[int] = None,
+                 w4_weight_layout: Optional[int] = None,
                  tp: Optional[int] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None,
                  **kwargs):

@@ -179,6 +180,7 @@ class TurboMind:
             engine_config = _update_engine_config(engine_config,
                                                   model_format=model_format,
                                                   group_size=group_size,
+                                                  w4_weight_layout=w4_weight_layout,
                                                   tp=tp,
                                                   **kwargs)

@@ -304,6 +306,7 @@ class TurboMind:
             output_format = 'w4'
             data_type = 'int4'
             cfg.group_size = 128
+            cfg.w4_weight_layout = 2
         else:
             # output_format = update_output_format(cfg.model_name,
             #                                      inferred_model_format,

@@ -378,6 +381,7 @@ class TurboMind:
         self.config = cfg
         self.model_name = cfg.model_name
         self.data_type = cfg.weight_type
+        print("from_workspace_cfg:", cfg)
         # create model
         logger.warning(f'model_config:\n\n{cfg.toini()}')
setup.py

@@ -69,8 +69,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
         file.writelines(lines)
         file.close()

+def copy_ck_so():
+    lmdeploy_root = os.path.dirname(os.path.abspath(__file__))
+    so_path = os.path.join(os.path.join(lmdeploy_root, "3rdparty"), "libgemm_multiB_int4.so")
+    # dtk version
+    if os.getenv("ROCM_PATH"):
+        rocm_path = os.getenv('ROCM_PATH', "")
+        rocm_so_path = os.path.join(rocm_path, 'lib')
+        print("rocm_so_path:", rocm_so_path)
+        shutil.copy(so_path, rocm_so_path)
+    else:
+        shutil.copy(so_path, "usr/local/lib")

 def get_version():
+    copy_ck_so()
     get_version_add()
     version_file = 'lmdeploy/version.py'
     with open(version_file, encoding='utf-8') as f:
src/turbomind/kernels/CMakeLists.txt

@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
 #set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
 #set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-# add_subdirectory(gemm_s_f16)
+add_subdirectory(gemm_s_f16)
 add_subdirectory(decoder_multihead_attention)
src/turbomind/kernels/gemm_s_f16/CMakeLists.txt

 # Copyright (c) OpenMMLab. All rights reserved.
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
-add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
+add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu ../../models/llama/awq_sugon/gemm_w4_dequation.cu)
 target_compile_options(gemm_s4_f16 PRIVATE
     --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
 set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON)
src/turbomind/kernels/gemm_s_f16/common.h

@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
     // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
     //     : "=r"(h[0])
     //     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    h[0] = (i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM;
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
     // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
     //     : "=r"(h[1])
     //     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    h[1] = (i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM;
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
     // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
     //     : "=r"(h[2])
     //     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    h[2] = (top_i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM;
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
     // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
     //     : "=r"(h[3])
     //     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    printf("=========common.h 86\n");
+    h[3] = (top_i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM;
     // I use inline PTX below because I am not sure if the compiler will emit
     // float2half instructions if I use the half2 ctor. In this case, I chose
     // performance reliability over code readability.

@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
     // Finally, we construct the output numbers.
     // Convert elt_01
-    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
-    // // Convert elt_23
-    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
-    // // Convert elt_45
-    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
-    // // Convert elt_67
-    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    h[0] = h[0] - FP16_TOP_MAGIC_NUM;
+    // Convert elt_23
+    //asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[1] = h[1] * ONE_SIXTEENTH + NEG_64;
+    // Convert elt_45
+    //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    h[2] = h[2] - FP16_TOP_MAGIC_NUM;
+    // Convert elt_67
+    //asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[3] = h[3] * ONE_SIXTEENTH + NEG_64;
     return result;
 }

@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
     // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
     // dependency if we issue immediately before required.
     const uint32_t top_i4s = i4s >> 8;
-    // if (0) {  // 1024 & 64
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
-    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
-    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
-    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
-    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
-    // }
-    // else {  // 64 only, trade 4 hfma2 with 2 shifts
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
-    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
-    //     h[0] <<= 4;
-    //     h[2] <<= 4;
-    //     // we don't need to subtract the magic nums because zeros will go through the same dequant function
-    //     // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
-    // }
+    printf("=========common.h 133\n");
+    // 64 only, trade 4 hfma2 with 2 shifts
+    h[0] = (i4s & BOT_MASK) | MAGIC_NUM_2;
+    h[1] = (i4s & TOP_MASK) | MAGIC_NUM_1;
+    h[2] = (top_i4s & BOT_MASK) | MAGIC_NUM_2;
+    h[3] = (top_i4s & TOP_MASK) | MAGIC_NUM_1;
+    h[0] <<= 4;
+    h[2] <<= 4;
+    // we don't need to subtract the magic nums because zeros will go through the same dequant function
+    // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
     return result;
 }

 __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
 {
     uint32_t smem_int_ptr;

@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
 __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
 {
-    uint s, z;
-    (half2&)z = __halves2half2(q.x, q.x);
-    (half2&)s = __halves2half2(q.y, q.y);
-    auto& t = (const uint&)x;
-    uint u, v;
+    // uint s, z;
+    // (half2&)z = __halves2half2(q.x, q.x);
+    // (half2&)s = __halves2half2(q.y, q.y);
+    // auto& t = (const uint&)x;
+    uint v;
     // if (TURBOMIND_S4_DEQUANT_USE_FMA) {
     //     asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
     // }

@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
     //     asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
     //     asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
     // }
+    printf("=========common.h 235\n");
     return (half2&)v;
 }
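The scalar replacements above implement the usual int4-to-fp16 "magic number" trick without PTX: OR-ing a 4-bit nibble into the mantissa of the fp16 constant 1024 (bit pattern 0x6400) yields 1024 + x, so subtracting 1024 recovers the integer, and for the high nibble (which lands as 16·x) scaling by 1/16 and adding −64 does the same. A small NumPy check of that identity, for illustration only (it is not part of the commit):

    import numpy as np

    # Low nibble: (lane & 0x000f) | 0x6400 encodes 1024 + x in fp16,
    # which is what h[0] = (i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM builds
    # per 16-bit lane before the FP16_TOP_MAGIC_NUM subtraction.
    for x in range(16):
        lane = np.array([0x6400 | x], dtype=np.uint16).view(np.float16)[0]
        assert float(lane) - 1024.0 == x

    # High nibble: the lane carries x << 4, i.e. 1024 + 16*x; multiplying by
    # 1/16 and adding -64 (ONE_SIXTEENTH / NEG_64 in the kernel) recovers x.
    for x in range(16):
        lane = np.array([0x6400 | (x << 4)], dtype=np.uint16).view(np.float16)[0]
        assert float(lane) * (1 / 16) - 64.0 == x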
src/turbomind/kernels/gemm_s_f16/format.cu

 // Copyright (c) OpenMMLab. All rights reserved.

 #include "common.h"
+#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"
 #include <iostream>

 namespace turbomind {

@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
     // permutation for [k, m/8] layout
     Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
     // |warp| lane | 2x2 | a0-7 |
-    permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
+    //permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
+    permute_u4<0, 1, 2, 3, 4, 5, 6, 7, 8, 9><<<512, 512, 0, st>>>(dst, src, shape);
 }

+void reformat_s4_k_m8_tarnsw4(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
+{
+    // permutation for [k, m/8] layout
+    Array<int, 10> shape{1, k / 8, 2, 2, 2, 1, m / 8, 2, 2, 2};
+    // 0123456-->4,6,7,5,0,3,1,2
+    //permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
+    permute_u4<5, 6, 8, 9, 7, 0, 1, 4, 2, 3><<<512, 512, 0, st>>>(dst, src, shape);
+}

 __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)

@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
     reformat_s4_k_m8(A_dst, A_src, m, k, st);
 }

+void convert_s4_k_m8_(uint32_t*       A_dst,
+                      half2*          Q_dst,
+                      half*           workspace,
+                      const uint32_t* A_src,
+                      const half*     scales,
+                      const uint32_t* qzeros,
+                      int             m,
+                      int             k,
+                      int             group_size,
+                      cudaStream_t    st)
+{
+    dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
+    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
+    reformat_s4_k_m8_tarnsw4(A_dst, A_src, m, k, st);
+}

 void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
 {
     Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};

@@ -140,5 +167,82 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
 {
     dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
 }

+__global__ void dequant_kernel(
+    int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
+{
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+    if (id >= num_kernels)
+        return;
+    int   j   = id % n;
+    int   i   = id / n;
+    half  x   = zeros_and_scales[i / group_size * n + j].data[0];
+    half  y   = zeros_and_scales[i / group_size * n + j].data[1];
+    float tmp = (weight[id] - x) * y;
+    weight[id] = __float2half(tmp);
+}
+
+__global__ void dequant_kernel_colmajor(
+    int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
+{
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+    if (id >= num_kernels)
+        return;
+    int   j   = id / group_size;
+    half  x   = zeros_and_scales[j].data[0];
+    half  y   = zeros_and_scales[j].data[1];
+    float tmp = (weight[id] - x) * y;
+    weight[id] = __float2half(tmp);
+}
+
+void dequant_w4_gemm(cudaStream_t stream, half* output, const uint32_t* weight,
+                     const half2* zeros_and_scales, int k, int n, int group_size)
+{
+    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
+    int num_kernels = k * n;
+    dequant_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
+        num_kernels, output, zeros_and_scales, k, n, group_size);
+}
+
+void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output, const uint32_t* weight,
+                              const half2* zeros_and_scales, int k, int n, int group_size)
+{
+    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
+    int num_kernels = k * n;
+    dequant_kernel_colmajor<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
+        num_kernels, output, zeros_and_scales, k, n, group_size);
+}
+
+__global__ void FusedSiluActivation_kernel(int num_kernels, half* output, const uint32_t* src, int m, int n)
+{
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+    if (id >= num_kernels)
+        return;
+    auto  data = ((half2*)src)[id];
+    float x    = __half2float(data.data[0]);
+    float y    = __half2float(data.data[1]);
+    float silu = x / (1.f + __expf(-x)) * y;
+    output[id] = __float2half(silu);
+}
+
+__global__ void assign_kernel(int num_kernels, half* output, const half* src, int m, int n)
+{
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+    if (id >= num_kernels)
+        return;
+    output[id] = src[id];
+}
+
+void addFusedSiluActivation(cudaStream_t stream, half* output, const half* src, int m, int n, int type)
+{
+    int num_kernels = m * n;
+    switch (type) {
+        case 0:
+            assign_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
+                num_kernels, output, src, m, n);
+            break;
+        case 1:
+            FusedSiluActivation_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
+                int(num_kernels / 2), output, (const uint32_t*)src, m, n);
+            break;
+        default:
+            return;
+    }
+}

 }  // namespace turbomind
src/turbomind/kernels/gemm_s_f16/format.h

@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
                      int             group_size,
                      cudaStream_t    st = {});

+void convert_s4_k_m8_(uint32_t*       A_dst,
+                      half2*          Q_dst,
+                      half*           workspace,
+                      const uint32_t* A_src,
+                      const half*     scales,
+                      const uint32_t* qzeros,
+                      int             m,
+                      int             k,
+                      int             group_size,
+                      cudaStream_t    st = {});

 void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});

 void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h

@@ -13,6 +13,10 @@
 namespace turbomind {

 extern bool g_dump_kernel_info_once;

+void dequant_w4_gemm(cudaStream_t stream, half* output, const uint32_t* weight, const half2* zeros_and_scales, int k, int n, int group_size);
+void addFusedSiluActivation(cudaStream_t stream, half* output, const half* src, int m, int n, int type);
+void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output, const uint32_t* weight, const half2* zeros_and_scales, int k, int n, int group_size);

 class GemmS4F16 {
 public:
src/turbomind/kernels/gemm_s_f16/gemm_template.h

@@ -5,7 +5,7 @@
 #include "common.h"
 #include "cta_iterator.h"
 #include "warp_iterator.h"
-#include <cuda_pipeline_primitives.h>
+// #include <cuda_pipeline_primitives.h>

 namespace turbomind {

(In the hunks below the commit comments out the previously active CUDA-specific GEMM code line by line; each "+" line replaces the same line without the leading "//".)

@@ -48,19 +48,19 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
 __inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
 {
+    // int src_lane = lane_id / 8 + lane_id % 4 * 8;
+    // uint u0 = __shfl_sync(0xffffffff, value, src_lane);
+    // uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4);
     short2 r;
+    // if (lane_id % 8 < 4) {
+    //     r.x = ((short2&)u0).x;
+    //     r.y = ((short2&)u1).x;
+    // }
+    // else {
+    //     r.x = ((short2&)u0).y;
+    //     r.y = ((short2&)u1).y;
+    // }
     return (uint&)r;
 }

@@ -87,6 +87,7 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
     // #else
     //     return transpose_m8n8_b16_warp_shuffle(a, lane_id);
     // #endif
+    return a;
 }

 namespace ops {

@@ -158,61 +159,61 @@ struct Gemm {
                          int& gemm_iter)
     {
+        // constexpr int ITER_M = WARP_M / OP_M;
+        // constexpr int ITER_N = WARP_N / OP_N;
+        // constexpr int ITER_K = WARP_K / OP_K;
+        // constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
+        // constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
+        // constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
+        // auto frag_C_ptr = (Array<float, 4>*)accum;  // [ITER_N, ITER_M]
+        // PRAGMA_UNROLL
+        // for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
+        //     warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
+        //     warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
+        //     auto warp_frag_A = warp_frag_A_[iter_k % 2];
+        //     auto warp_frag_B = warp_frag_B_[iter_k % 2];
+        //     PRAGMA_UNROLL
+        //     for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
+        //         PRAGMA_UNROLL
+        //         for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
+        //             auto& frag_A = warp_frag_A[iter_m];
+        //             auto& frag_B = warp_frag_B[iter_n];
+        //             auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
+        //             mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
+        //         }
+        //     }
+        //     if (iter_k < ITER_K - 1) {
+        //         iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
+        //         iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
+        //         iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
+        //     }
+        //     if (iter_k == ITER_K - 2) {
+        //         iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
+        //         iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
+        //         iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
+        //         __pipeline_commit();
+        //         __pipeline_wait_prior(STAGES - 2);
+        //         sync_slice(slice_id);
+        //         iter_A.next_stage();
+        //         iter_Q.next_stage();
+        //         iter_B.next_stage();
+        //         warp_iter_A.next_stage();
+        //         warp_iter_B.next_stage();
+        //         --gemm_iter;
+        //     }
+        // }
     }

 template<typename T, int N>

@@ -235,35 +236,35 @@ struct Gemm {
     __device__ void sync_slice(int slice_id)
     {
+        // if constexpr (SLICES == 1) {
+        //     __syncthreads();
+        // }
+        // else {
+        //     constexpr int SLICE_GROUP = (SLICES + 7) / 8;
+        //     constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
+        //     const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
+        //     // asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
+        // }
     }

     __device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
     {
+        // if (slice_id == 0) {
+        //     PRAGMA_UNROLL
+        //     for (int i = 0; i < CTA_N; ++i) {
+        //         tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
+        //     }
+        // }
     }

     __device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
     {
+        // if (slice_id == 0) {
+        //     PRAGMA_UNROLL
+        //     for (int i = 0; i < CTA_N; ++i) {
+        //         partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
+        //     }
+        // }
     }

     template<int Index>

@@ -280,80 +281,80 @@ struct Gemm {
                     int slice_id)
     {
+        // if (slice_id != 0) {
+        //     return;
+        // }
+        // // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
+        // PRAGMA_UNROLL
+        // for (int i = 0; i < WARP_N / OP_N; ++i) {
+        //     const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
+        //     const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
+        //     PRAGMA_UNROLL
+        //     for (int j = 0; j < WARP_M / OP_M; ++j) {
+        //         PRAGMA_UNROLL
+        //         for (int x = 0; x < 2; ++x) {
+        //             const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
+        //             // convert to half
+        //             half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
+        //             // transpose 8x8 accum tile
+        //             uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
+        //             // store to global memory
+        //             OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
+        //         }
+        //     }
+        // }
     }

     __device__ void
     sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
     {
+        // int offset_m = warp_id_m * WARP_M / OP_M;
+        // int offset_n = warp_id_n * WARP_N / OP_N;
+        // PRAGMA_UNROLL
+        // for (int z = 0; z < SLICES; ++z) {
+        //     if (slice_id == z) {
+        //         PRAGMA_UNROLL
+        //         for (int i = 0; i < WARP_N / OP_N; ++i) {
+        //             PRAGMA_UNROLL
+        //             for (int j = 0; j < WARP_M / OP_M; ++j) {
+        //                 PRAGMA_UNROLL
+        //                 for (int x = 0; x < 4; ++x) {
+        //                     int src = (i * WARP_M / OP_M + j) * 4 + x;
+        //                     int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
+        //                     if (z > 0) {
+        //                         using namespace ops;
+        //                         tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
+        //                     }
+        //                     tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
+        //                 }
+        //             }
+        //         }
+        //     }
+        //     __syncthreads();
+        // }
+        // if (slice_id == 0) {
+        //     PRAGMA_UNROLL
+        //     for (int i = 0; i < WARP_N / OP_N; ++i) {
+        //         PRAGMA_UNROLL
+        //         for (int j = 0; j < WARP_M / OP_M; ++j) {
+        //             PRAGMA_UNROLL
+        //             for (int x = 0; x < 4; ++x) {
+        //                 int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
+        //                 int dst = (i * WARP_M / OP_M + j) * 4 + x;
+        //                 tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
+        //             }
+        //         }
+        //     }
+        // }
     }

+    // Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
+    // Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];

     __device__ void run_v2(half* __restrict__ C,
                            const uint* __restrict__ A,

@@ -364,89 +365,89 @@ struct Gemm {
                            int K,
                            int output_op_idx)
     {
+        // static_assert(WARP_M % OP_N == 0);
+        // float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
+        // extern __shared__ uint8_t smem[];
+        // const int warp_id = threadIdx.x / WARP_SIZE;
+        // const int lane_id = threadIdx.x % WARP_SIZE;
+        // const int warp_id_m  = warp_id % kWarpCountM;
+        // const int warp_id_nk = warp_id / kWarpCountM;
+        // const int warp_id_n  = warp_id_nk % kWarpCountN;
+        // const int warp_id_k  = warp_id_nk / kWarpCountN;
+        // const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
+        // const int slice_id = warp_id_k;
+        // const int cta_k = slice_id * SLICE_K;  // sliced-k offset
+        // const int cta_m = blockIdx.x * CTA_M;
+        // const int cta_n = blockIdx.y * CTA_N;
+        // // each slice has its own partition of smem
+        // uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
+        // half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
+        // // [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
+        // float* const tb_smem_C = (float*)smem;
+        // __shared__ typename IteratorQ::Storage tb_smem_Q_storage;
+        // auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
+        // IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
+        // IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
+        // IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
+        // const int offset_m = warp_id_m * WARP_M + lane_id;
+        // WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
+        // WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
+        // int gemm_iter = (K + CTA_K - 1) / CTA_K;
+        // PRAGMA_UNROLL
+        // for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
+        //     iter_A.prefetch_stage(gemm_iter > 0);
+        //     iter_Q.prefetch_stage(gemm_iter > 0);
+        //     iter_B.prefetch_stage(gemm_iter > 0);
+        //     __pipeline_commit();
+        // }
+        // clear(tb_frag_C);
+        // __pipeline_wait_prior(STAGES - 2);
+        // sync_slice(slice_id);
+        // warp_iter_A.load(warp_frag_A_[0], 0);
+        // warp_iter_B.load(warp_frag_B_[0], 0);
+        // PRAGMA_NO_UNROLL
+        // for (; gemm_iter > -STAGES + 1;) {
+        //     warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
+        // }
+        // __pipeline_commit();
+        // __pipeline_wait_prior(0);
+        // __syncthreads();
+        // if constexpr (SLICES > 1) {
+        //     sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
+        // }
+        // switch (output_op_idx) {
+        //     case 0:
+        //         store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
+        //         break;
+        //     case 1:
+        //         store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
+        //         break;
+        //     default:
+        //         return;
+        // }
     }
 };
src/turbomind/models/llama/BlockManager.cc

@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
         return false;
     }

+    //auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
     auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size);
     if (!ptr) {
         return false;
src/turbomind/models/llama/CMakeLists.txt

@@ -19,13 +19,14 @@ add_library(Llama STATIC
       unified_attention_layer.cc
       llama_kernels.cu
       llama_decoder_kernels.cu
-      llama_utils.cu)
+      llama_utils.cu
+      ./awq_sugon/gemm_w4_dequation.cu)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
 #set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
 #set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

 target_link_libraries(Llama PUBLIC cudart
-      # gemm_s4_f16
+      gemm_s4_f16
       cublasMMWrapper
       DynamicDecodeLayer
       activation_kernels

@@ -41,7 +42,8 @@ target_link_libraries(Llama PUBLIC cudart
       memory_utils
       nccl_utils
       cuda_utils
-      logger)
+      logger
+      gemm_multiB_int4)
 # llama_fmha)
 if (NOT MSVC)
src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
View file @ d26f4c73
...
@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
                                                     size_t     inter_size,
                                                     WeightType weight_type,
                                                     int        group_size,
+                                                    int        w4_weight_layout,
                                                     bool       attn_bias,
                                                     size_t     tensor_para_size,
                                                     size_t     tensor_para_rank):
...
@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
     self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
     self_attn_weights.qkv.type        = weight_type;
     self_attn_weights.qkv.group_size  = group_size;
+    self_attn_weights.qkv.w4_weight_layout = w4_weight_layout;

     self_attn_weights.output.input_dims  = hidden_units_ / tensor_para_size_;
     self_attn_weights.output.output_dims = hidden_units_;
     self_attn_weights.output.type        = weight_type;
     self_attn_weights.output.group_size  = group_size;
+    self_attn_weights.output.w4_weight_layout = w4_weight_layout;

     ffn_weights.gating.input_dims  = hidden_units_;
     ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
     ffn_weights.gating.type        = weight_type;
     ffn_weights.gating.group_size  = group_size;
+    ffn_weights.gating.w4_weight_layout = w4_weight_layout;

     ffn_weights.intermediate.input_dims  = hidden_units_;
     ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
     ffn_weights.intermediate.type        = weight_type;
     ffn_weights.intermediate.group_size  = group_size;
+    ffn_weights.intermediate.w4_weight_layout = w4_weight_layout;

     ffn_weights.fused_gating_intermediate.input_dims  = hidden_units_;
     ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
     ffn_weights.fused_gating_intermediate.type        = weight_type;
     ffn_weights.fused_gating_intermediate.group_size  = group_size;
+    ffn_weights.fused_gating_intermediate.w4_weight_layout = w4_weight_layout;

     ffn_weights.output.input_dims  = inter_size_ / tensor_para_size_;
     ffn_weights.output.output_dims = hidden_units_;
     ffn_weights.output.type        = weight_type;
     ffn_weights.output.group_size  = group_size;
+    ffn_weights.output.w4_weight_layout = w4_weight_layout;

     mallocWeights();
 }
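The constructor hunk above fans the new w4_weight_layout value out to every quantized projection, next to group_size. Below is a minimal sketch of that plumbing with illustrative names (LinearW4, make_layer_weights, not lmdeploy's actual types), assuming 0 means the original AWQ layout and 1/2 the alternative layouts that the later branches check for.

#include <cstddef>
#include <initializer_list>

struct LinearW4 {
    size_t input_dims       = 0;
    size_t output_dims      = 0;
    int    group_size       = 128;
    int    w4_weight_layout = 0;  // 0: default layout; 1/2: alternative packed layouts
};

struct LayerWeightsW4 {
    LinearW4 qkv, attn_output, gating, intermediate, fused_gating_intermediate, ffn_output;
};

// Fan the per-model quantization settings out to every projection of one layer.
LayerWeightsW4 make_layer_weights(size_t hidden, size_t inter, size_t tp, int group_size, int layout)
{
    LayerWeightsW4 w{};
    for (LinearW4* p : {&w.qkv, &w.attn_output, &w.gating, &w.intermediate,
                        &w.fused_gating_intermediate, &w.ffn_output}) {
        p->group_size       = group_size;
        p->w4_weight_layout = layout;
    }
    w.gating.input_dims                     = hidden;
    w.gating.output_dims                    = inter / tp;
    w.fused_gating_intermediate.input_dims  = hidden;
    w.fused_gating_intermediate.output_dims = inter / tp * 2;
    // ...remaining dims follow the same pattern as the hunk above.
    return w;
}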
...
@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
     else {  // int8, int4
         const int factor = sizeof(float) * 8 / bit_size;
         FT_CHECK(weights.input_dims % factor == 0);
-        deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
-        deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
-        // interleaved scales/zeros
-        deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
+        // // read the environment variable
+        // int m_weightlayout_switch=1;
+        // const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
+        // if (env_weightlayout_str != nullptr) {
+        //     m_weightlayout_switch = std::stoi(env_weightlayout_str);
+        // }
+        if ((weights.input_dims % 4096 == 0) && (weights.w4_weight_layout == 1 || weights.w4_weight_layout == 2)) {
+            size_t new_input_dims = weights.input_dims + 2 * weights.group_size;
+            deviceMalloc((int**)&weights.kernel, new_input_dims * weights.output_dims / factor);
+            deviceMemSetZero((int*)weights.kernel, new_input_dims * weights.output_dims / factor);
+            // interleaved scales/zeros
+            deviceMalloc((T**)&weights.scales_and_zeros, new_input_dims / weights.group_size * weights.output_dims * 2);
+        }
+        else {
+            deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
+            deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
+            // interleaved scales/zeros
+            deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
+        }
     }
 }
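The new mallocWeights() branch pads input_dims by 2 * group_size rows (only for layouts 1 and 2, and only when input_dims is a multiple of 4096) before sizing the packed int4 kernel and the interleaved scales/zeros. The standalone sketch below reproduces that size arithmetic with plausible Llama-7B-like dimensions; the numbers are illustrative, not read from any model.

#include <cstddef>
#include <cstdio>

int main()
{
    const size_t input_dims  = 4096;
    const size_t output_dims = 4096;
    const size_t group_size  = 128;
    const int    bit_size    = 4;                              // AWQ int4
    const size_t factor      = sizeof(float) * 8 / bit_size;   // 8 int4 values per int32

    const size_t padded_in = input_dims + 2 * group_size;      // layout 1/2 padding

    const size_t kernel_elems = padded_in * output_dims / factor;           // int32 elements
    const size_t sz_elems     = padded_in / group_size * output_dims * 2;   // scale + zero per group

    std::printf("qweight int32 elems: %zu\n", kernel_elems);
    std::printf("scales_zeros elems : %zu\n", sz_elems);
    return 0;
}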
...
@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
     }
     else {  // int8, int4
         const int factor = sizeof(float) * 8 / bit_size;
-        output.insert(get_name("qweight"),
-                      Tensor{MEMORY_GPU,
-                             TYPE_INT32,
-                             {weights.input_dims * weights.output_dims * sizeof(int) / factor},
-                             weights.kernel});
-        output.insert(get_name("scales_zeros"),
-                      Tensor{MEMORY_GPU,
-                             getTensorType<T>(),
-                             {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
-                             weights.scales_and_zeros});
+        // // read the environment variable
+        // int m_weightlayout_switch=1;
+        // const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
+        // if (env_weightlayout_str != nullptr) {
+        //     m_weightlayout_switch = std::stoi(env_weightlayout_str);
+        // }
+        if ((weights.input_dims % 4096 == 0) && (weights.w4_weight_layout == 1 || weights.w4_weight_layout == 2)) {
+            size_t new_input_dims = weights.input_dims + weights.group_size;
+            output.insert(get_name("qweight"),
+                          Tensor{MEMORY_GPU,
+                                 TYPE_INT32,
+                                 {new_input_dims * weights.output_dims * sizeof(int) / factor},
+                                 weights.kernel});
+            output.insert(get_name("scales_zeros"),
+                          Tensor{MEMORY_GPU,
+                                 getTensorType<T>(),
+                                 {new_input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
+                                 weights.scales_and_zeros});
+        }
+        else {
+            output.insert(get_name("qweight"),
+                          Tensor{MEMORY_GPU,
+                                 TYPE_INT32,
+                                 {weights.input_dims * weights.output_dims * sizeof(int) / factor},
+                                 weights.kernel});
+            output.insert(get_name("scales_zeros"),
+                          Tensor{MEMORY_GPU,
+                                 getTensorType<T>(),
+                                 {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
+                                 weights.scales_and_zeros});
+        }
     }
 }
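Worth noting: in this hunk getWeightTensor() pads with a single group_size (new_input_dims = input_dims + weights.group_size), while mallocWeights() above and loadWeights() below pad with 2 * group_size; the diff does not say whether the asymmetry is intentional. The sketch below only makes the two resulting qweight byte counts visible side by side, with illustrative dimensions.

#include <cstddef>
#include <cstdio>

int main()
{
    const size_t input_dims = 4096, output_dims = 4096, group_size = 128;
    const size_t factor = 8;  // eight int4 values per packed int32

    const size_t alloc_rows  = input_dims + 2 * group_size;  // mallocWeights()/loadWeights()
    const size_t tensor_rows = input_dims + group_size;      // getWeightTensor()

    std::printf("allocated qweight bytes: %zu\n", alloc_rows * output_dims * sizeof(int) / factor);
    std::printf("exported  qweight bytes: %zu\n", tensor_rows * output_dims * sizeof(int) / factor);
    return 0;
}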
...
@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w,
         FT_CHECK(dim1 % factor == 0);
-        std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
-        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
-        const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
-        loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
+        // // read the environment variable
+        // int m_weightlayout_switch=1;
+        // const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
+        // if (env_weightlayout_str != nullptr) {
+        //     m_weightlayout_switch = std::stoi(env_weightlayout_str);
+        // }
+        if ((dim0 % 4096 == 0) && (w.w4_weight_layout == 1 || w.w4_weight_layout == 2)) {
+            size_t new_dim0 = dim0 + 2 * w.group_size;
+            std::vector<size_t> w_shape{new_dim0, dim1 / factor * sizeof(uint32_t)};
+            loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
+            const size_t group_count = w.group_size > 0 ? new_dim0 / w.group_size : 1;
+            loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
+        }
+        else {
+            std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
+            loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
+            const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
+            loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
+        }
     }
 }
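For layouts 1/2, loadWeights() therefore expects the converter to have written a padded checkpoint: the ".qweight" file carries 2 * group_size extra rows and ".scales_zeros" the corresponding extra groups. A rough sketch of the expected on-disk shapes (dimensions are illustrative, suffixes follow the hunk above):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const size_t dim0 = 4096, dim1 = 4096, group_size = 128;
    const size_t factor = 8;  // eight int4 values per packed int32

    const size_t new_dim0 = dim0 + 2 * group_size;
    const std::vector<size_t> qweight_shape{new_dim0, dim1 / factor * sizeof(uint32_t)};
    const size_t group_count = group_size > 0 ? new_dim0 / group_size : 1;
    const std::vector<size_t> scales_zeros_shape{group_count, dim1 * 2};

    std::printf("qweight      : %zu x %zu bytes (int8 view of packed int4)\n",
                qweight_shape[0], qweight_shape[1]);
    std::printf("scales_zeros : %zu x %zu half elements\n",
                scales_zeros_shape[0], scales_zeros_shape[1]);
    return 0;
}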
...
src/turbomind/models/llama/LlamaDecoderLayerWeight.h
View file @ d26f4c73
...
@@ -35,6 +35,7 @@ public:
                            size_t     inter_size,
                            WeightType weight_type,
                            int        group_size,
+                           int        w4_weight_layout,
                            bool       attn_bias,
                            size_t     tensor_para_size,
                            size_t     tensor_para_rank);
...
src/turbomind/models/llama/LlamaDenseWeight.h
View file @ d26f4c73
...
@@ -63,6 +63,7 @@ struct LlamaDenseWeight {
     T*  bias;
     T*  scales_and_zeros;
     int group_size;
+    int w4_weight_layout;
 };

 template<typename T>
...
src/turbomind/models/llama/LlamaFfnLayer.cc
View file @ d26f4c73
...
@@ -29,7 +29,7 @@ namespace turbomind {
 template<typename T>
 void LlamaFfnLayer<T>::allocateBuffer(size_t token_num)
 {
-    inter_buf_  = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * inter_size_, false);
+    inter_buf_  = (T*)allocator_->reMalloc(inter_buf_, 2 * sizeof(T) * token_num * inter_size_, false);
     gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * token_num * inter_size_, false);
     is_allocate_buffer_ = true;
 }
...
@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
     if (weights->fused_gating_intermediate.kernel) {
         NvtxScope scope("fused_silu_ffn");
-        linear_.forward(
-            gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
+        // linear_.forward(
+        //     gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
+        linear_.forward_ffn(
+            gating_buf_, inter_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
     }
     else {
         {  // w1(x)
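The two FFN changes go together: inter_buf_ is now sized for twice as many elements, and the fused path calls a forward_ffn() overload that receives both gating_buf_ and inter_buf_, presumably so the gate and up projections land in separate buffers instead of one interleaved output. The CPU-side reference below only illustrates the activation the fused path computes, silu(gate) * up; it is a sketch, not lmdeploy's GPU implementation.

#include <cmath>
#include <cstddef>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// gate, up: [token_num * inter_size]; the combined result is written back into gate.
void fused_silu_ffn_reference(std::vector<float>& gate, const std::vector<float>& up)
{
    for (size_t i = 0; i < gate.size(); ++i) {
        gate[i] = silu(gate[i]) * up[i];
    }
}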
...