Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
d26f4c73
"vscode:/vscode.git/clone" did not exist on "f08513d4066a629f58bc3c392821cc9a9183d41f"
Commit
d26f4c73
authored
May 27, 2024
by
gaoqiong
Browse files
增加awq模块
parent
2326380c
Changes
32
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
599 additions
and
289 deletions
+599
-289
CMakeLists.txt
CMakeLists.txt
+1
-1
lmdeploy/cli/cli.py
lmdeploy/cli/cli.py
+5
-1
lmdeploy/turbomind/deploy/converter.py
lmdeploy/turbomind/deploy/converter.py
+5
-0
lmdeploy/turbomind/deploy/target_model/base.py
lmdeploy/turbomind/deploy/target_model/base.py
+12
-0
lmdeploy/turbomind/deploy/target_model/w4.py
lmdeploy/turbomind/deploy/target_model/w4.py
+85
-7
lmdeploy/turbomind/turbomind.py
lmdeploy/turbomind/turbomind.py
+4
-0
setup.py
setup.py
+12
-0
src/turbomind/kernels/CMakeLists.txt
src/turbomind/kernels/CMakeLists.txt
+1
-1
src/turbomind/kernels/gemm_s_f16/CMakeLists.txt
src/turbomind/kernels/gemm_s_f16/CMakeLists.txt
+3
-1
src/turbomind/kernels/gemm_s_f16/common.h
src/turbomind/kernels/gemm_s_f16/common.h
+37
-39
src/turbomind/kernels/gemm_s_f16/format.cu
src/turbomind/kernels/gemm_s_f16/format.cu
+105
-1
src/turbomind/kernels/gemm_s_f16/format.h
src/turbomind/kernels/gemm_s_f16/format.h
+11
-0
src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h
src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h
+4
-0
src/turbomind/kernels/gemm_s_f16/gemm_template.h
src/turbomind/kernels/gemm_s_f16/gemm_template.h
+215
-214
src/turbomind/models/llama/BlockManager.cc
src/turbomind/models/llama/BlockManager.cc
+1
-0
src/turbomind/models/llama/CMakeLists.txt
src/turbomind/models/llama/CMakeLists.txt
+5
-3
src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+85
-18
src/turbomind/models/llama/LlamaDecoderLayerWeight.h
src/turbomind/models/llama/LlamaDecoderLayerWeight.h
+1
-0
src/turbomind/models/llama/LlamaDenseWeight.h
src/turbomind/models/llama/LlamaDenseWeight.h
+1
-0
src/turbomind/models/llama/LlamaFfnLayer.cc
src/turbomind/models/llama/LlamaFfnLayer.cc
+6
-3
No files found.
CMakeLists.txt
View file @
d26f4c73
...
@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
...
@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
# $<TARGET_OBJECTS:flash_attention2>
# $<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:LlamaTritonBackend>
#
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
$<TARGET_OBJECTS:TransformerTritonBackend>
...
...
lmdeploy/cli/cli.py
View file @
d26f4c73
...
@@ -61,7 +61,11 @@ class CLI(object):
...
@@ -61,7 +61,11 @@ class CLI(object):
default
=
0
,
default
=
0
,
help
=
'A parameter used in awq to quantize fp16 weights '
help
=
'A parameter used in awq to quantize fp16 weights '
'to 4 bits'
)
'to 4 bits'
)
parser
.
add_argument
(
'--w4-weight-layout'
,
type
=
int
,
default
=
2
,
help
=
'A parameter used in AWQ to control the layout of weight '
)
parser
.
set_defaults
(
run
=
CLI
.
convert
)
parser
.
set_defaults
(
run
=
CLI
.
convert
)
@
staticmethod
@
staticmethod
...
...
lmdeploy/turbomind/deploy/converter.py
View file @
d26f4c73
...
@@ -196,6 +196,7 @@ def main(model_name: str,
...
@@ -196,6 +196,7 @@ def main(model_name: str,
tp
:
int
=
1
,
tp
:
int
=
1
,
quant_path
:
str
=
None
,
quant_path
:
str
=
None
,
group_size
:
int
=
0
,
group_size
:
int
=
0
,
w4_weight_layout
:
int
=
2
,
**
kwargs
):
**
kwargs
):
"""deploy llama family models via turbomind.
"""deploy llama family models via turbomind.
...
@@ -215,6 +216,7 @@ def main(model_name: str,
...
@@ -215,6 +216,7 @@ def main(model_name: str,
quant_path (str): Path of the quantized model, which can be None.
quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
to 4 bits
w4_weight_layout (int) :a parameter used in AWQ to control the layout of weight
kwargs (dict): other params for convert
kwargs (dict): other params for convert
"""
"""
...
@@ -260,10 +262,13 @@ def main(model_name: str,
...
@@ -260,10 +262,13 @@ def main(model_name: str,
cfg
.
tensor_para_size
=
tp
cfg
.
tensor_para_size
=
tp
cfg
.
rotary_embedding
=
cfg
.
size_per_head
cfg
.
rotary_embedding
=
cfg
.
size_per_head
cfg
.
group_size
=
group_size
cfg
.
group_size
=
group_size
cfg
.
w4_weight_layout
=
w4_weight_layout
if
inferred_model_format
.
find
(
'awq'
)
!=
-
1
:
if
inferred_model_format
.
find
(
'awq'
)
!=
-
1
:
cfg
.
weight_type
=
'int4'
cfg
.
weight_type
=
'int4'
output_format
=
'w4'
output_format
=
'w4'
assert
group_size
>
0
,
f
'group_size:
{
group_size
}
should > 0'
assert
group_size
>
0
,
f
'group_size:
{
group_size
}
should > 0'
print
(
"w4_weight_layout:"
,
w4_weight_layout
)
assert
w4_weight_layout
>=
0
and
w4_weight_layout
<
3
,
f
'w4_weight_layout:
{
w4_weight_layout
}
should >= 0 and < 3'
else
:
else
:
#output_format = update_output_format(model_name, inferred_model_format,
#output_format = update_output_format(model_name, inferred_model_format,
# model_path, output_format)
# model_path, output_format)
...
...
lmdeploy/turbomind/deploy/target_model/base.py
View file @
d26f4c73
...
@@ -5,6 +5,7 @@ import inspect
...
@@ -5,6 +5,7 @@ import inspect
import
io
import
io
import
json
import
json
import
os.path
as
osp
import
os.path
as
osp
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
configparser
import
ConfigParser
from
configparser
import
ConfigParser
...
@@ -52,6 +53,7 @@ class TurbomindModelConfig:
...
@@ -52,6 +53,7 @@ class TurbomindModelConfig:
rope_theta
:
float
=
10000.0
rope_theta
:
float
=
10000.0
size_per_head
:
int
=
128
size_per_head
:
int
=
128
group_size
:
int
=
0
group_size
:
int
=
0
w4_weight_layout
:
int
=
2
max_batch_size
:
int
=
64
max_batch_size
:
int
=
64
max_context_token_num
:
int
=
1
max_context_token_num
:
int
=
1
step_length
:
int
=
1
step_length
:
int
=
1
...
@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
...
@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
self
.
to_file
=
to_file
self
.
to_file
=
to_file
self
.
out_dir
=
out_dir
self
.
out_dir
=
out_dir
self
.
tm_params
=
{}
self
.
tm_params
=
{}
#self.weight_layout= 1
#获取环境变量
#env_weight_layout = os.environ.get('LMDEPLOY_WEIGHTLAYOUT_SWITCH', '1')
#self.weight_layout =int(env_weight_layout)
#print("self.weight_layout:",self.weight_layout)
@
abstractmethod
@
abstractmethod
def
get_config
(
self
,
cfg
:
TurbomindModelConfig
)
->
TurbomindModelConfig
:
def
get_config
(
self
,
cfg
:
TurbomindModelConfig
)
->
TurbomindModelConfig
:
...
@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
...
@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
return
x
.
view
(
n_heads
,
2
,
dim
//
n_heads
//
2
,
return
x
.
view
(
n_heads
,
2
,
dim
//
n_heads
//
2
,
1
).
transpose
(
1
,
2
).
reshape
(
dim
,
1
)
1
).
transpose
(
1
,
2
).
reshape
(
dim
,
1
)
def
permute_trans
(
x
:
torch
.
Tensor
):
if
x
.
shape
[
-
1
]
>
1
:
dim
=
x
.
shape
[
-
1
]
return
x
.
view
(
-
1
,
x
.
shape
[
-
1
]).
transpose
(
0
,
1
).
reshape
(
dim
,
-
1
)
def
merge_qkv
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
tp
:
int
,
def
merge_qkv
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
tp
:
int
,
dim
:
int
):
dim
:
int
):
...
...
lmdeploy/turbomind/deploy/target_model/w4.py
View file @
d26f4c73
...
@@ -8,7 +8,7 @@ import lmdeploy
...
@@ -8,7 +8,7 @@ import lmdeploy
from
..source_model.base
import
BaseInputModel
,
BaseReader
from
..source_model.base
import
BaseInputModel
,
BaseReader
from
.base
import
(
OUTPUT_MODELS
,
BaseOutputModel
,
TurbomindModelConfig
,
from
.base
import
(
OUTPUT_MODELS
,
BaseOutputModel
,
TurbomindModelConfig
,
merge_qkv
,
permute
)
merge_qkv
,
permute
,
permute_trans
)
# import _turbomind as _tm
# import _turbomind as _tm
# TODO: find another way import _turbomind
# TODO: find another way import _turbomind
...
@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
...
@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
qw
.
size
(
-
1
)
*
8
,
qw
.
size
(
0
),
group_size
)
qw
.
size
(
-
1
)
*
8
,
qw
.
size
(
0
),
group_size
)
return
_qw
,
_sz
return
_qw
,
_sz
def
convert_s4_
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
group_size
:
int
):
assert
qw
.
is_contiguous
()
assert
qz
.
is_contiguous
()
assert
s
.
is_contiguous
()
_qw
=
torch
.
zeros_like
(
qw
)
_sz
=
torch
.
zeros_like
(
s
,
dtype
=
torch
.
int32
)
# half2
_ws
=
torch
.
zeros_like
(
s
)
_tm
.
convert_s4_k_m8_
(
_qw
,
_sz
,
_ws
,
qw
,
s
,
qz
,
qw
.
size
(
-
1
)
*
8
,
qw
.
size
(
0
),
group_size
)
return
_qw
,
_sz
def
tp_m_s4
(
x
:
torch
.
Tensor
,
tp
:
int
):
def
tp_m_s4
(
x
:
torch
.
Tensor
,
tp
:
int
):
return
x
.
view
(
x
.
size
(
0
)
//
32
,
tp
,
-
1
,
128
).
permute
(
0
,
2
,
3
,
return
x
.
view
(
x
.
size
(
0
)
//
32
,
tp
,
-
1
,
128
).
permute
(
0
,
2
,
3
,
...
@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
...
@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
"""Export transformer layer i."""
"""Export transformer layer i."""
group_size
=
self
.
cfg
.
group_size
group_size
=
self
.
cfg
.
group_size
tp
=
self
.
cfg
.
tensor_para_size
tp
=
self
.
cfg
.
tensor_para_size
w4_weight_layout
=
self
.
cfg
.
w4_weight_layout
size_per_head
=
self
.
cfg
.
size_per_head
size_per_head
=
self
.
cfg
.
size_per_head
# attn
# attn
q_qw
,
k_qw
,
v_qw
,
o_qw
=
get_cuda_tensor
(
bin
.
attn
(
i
))
q_qw
,
k_qw
,
v_qw
,
o_qw
=
get_cuda_tensor
(
bin
.
attn
(
i
))
...
@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel):
...
@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel):
qkv_qz
=
merge_qkv
(
q_qz
,
k_qz
,
v_qz
,
tp
,
dim
=
2
)
qkv_qz
=
merge_qkv
(
q_qz
,
k_qz
,
v_qz
,
tp
,
dim
=
2
)
qkv_s
=
merge_qkv
(
q_s
,
k_s
,
v_s
,
tp
,
dim
=
2
)
qkv_s
=
merge_qkv
(
q_s
,
k_s
,
v_s
,
tp
,
dim
=
2
)
qkv_qw
,
qkv_sz
=
convert_s4
(
qkv_qw
,
qkv_qz
,
qkv_s
,
group_size
)
pad_group_count
=
2
qkv_qw
=
tp_m_s4
(
qkv_qw
,
tp
)
if
w4_weight_layout
==
1
or
w4_weight_layout
==
2
:
if
qkv_qw
.
shape
[
0
]
%
4096
==
0
:
qkv_qw_padding
=
torch
.
zeros
(
group_size
*
pad_group_count
,
qkv_qw
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
qkv_qw
=
torch
.
cat
((
qkv_qw
,
qkv_qw_padding
),
dim
=
0
).
contiguous
()
qkv_qz_padding
=
torch
.
zeros
(
pad_group_count
,
qkv_qz
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
qkv_qz
=
torch
.
cat
((
qkv_qz
,
qkv_qz_padding
),
dim
=
0
).
contiguous
()
qkv_s_padding
=
torch
.
zeros
(
pad_group_count
,
qkv_s
.
shape
[
1
],
dtype
=
torch
.
float16
).
cuda
()
qkv_s
=
torch
.
cat
((
qkv_s
,
qkv_s_padding
),
dim
=
0
).
contiguous
()
qkv_qw
,
qkv_sz
=
convert_s4_
(
qkv_qw
,
qkv_qz
,
qkv_s
,
group_size
)
qkv_qw
=
tp_m_s4
(
qkv_qw
,
tp
)
qkv_sz
=
permute_trans
(
qkv_sz
)
else
:
qkv_qw
,
qkv_sz
=
convert_s4
(
qkv_qw
,
qkv_qz
,
qkv_s
,
group_size
)
qkv_qw
=
tp_m_s4
(
qkv_qw
,
tp
)
#print("请设置weight layout\n")
self
.
save_split
(
qkv_qw
,
f
'layers.
{
i
}
.attention.w_qkv.qweight'
,
-
1
)
self
.
save_split
(
qkv_qw
,
f
'layers.
{
i
}
.attention.w_qkv.qweight'
,
-
1
)
self
.
save_split
(
qkv_sz
,
f
'layers.
{
i
}
.attention.w_qkv.scales_zeros'
,
-
1
)
self
.
save_split
(
qkv_sz
,
f
'layers.
{
i
}
.attention.w_qkv.scales_zeros'
,
-
1
)
o_qw
,
o_sz
=
convert_s4
(
o_qw
,
o_qz
,
o_s
,
group_size
)
if
w4_weight_layout
==
1
or
w4_weight_layout
==
2
:
if
o_qw
.
shape
[
0
]
%
4096
==
0
:
o_qw_padding
=
torch
.
zeros
(
group_size
*
pad_group_count
,
o_qw
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
o_qw
=
torch
.
cat
((
o_qw
,
o_qw_padding
),
dim
=
0
).
contiguous
()
o_qz_padding
=
torch
.
zeros
(
pad_group_count
,
o_qz
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
o_qz
=
torch
.
cat
((
o_qz
,
o_qz_padding
),
dim
=
0
).
contiguous
()
o_s_padding
=
torch
.
zeros
(
pad_group_count
,
o_s
.
shape
[
1
],
dtype
=
torch
.
float16
).
cuda
()
o_s
=
torch
.
cat
((
o_s
,
o_s_padding
),
dim
=
0
).
contiguous
()
o_qw
,
o_sz
=
convert_s4_
(
o_qw
,
o_qz
,
o_s
,
group_size
)
o_sz
=
permute_trans
(
o_sz
)
else
:
o_qw
,
o_sz
=
convert_s4
(
o_qw
,
o_qz
,
o_s
,
group_size
)
self
.
save_split
(
o_qw
,
f
'layers.
{
i
}
.attention.wo.qweight'
,
0
)
self
.
save_split
(
o_qw
,
f
'layers.
{
i
}
.attention.wo.qweight'
,
0
)
self
.
save_split
(
o_sz
,
f
'layers.
{
i
}
.attention.wo.scales_zeros'
,
0
)
self
.
save_split
(
o_sz
,
f
'layers.
{
i
}
.attention.wo.scales_zeros'
,
0
)
...
@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
...
@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
w13_qw
,
w13_qz
,
w13_s
=
fuse_w1_w3_s4
(
w1_qw
,
w1_qz
,
w1_s
,
w3_qw
,
w3_qz
,
w13_qw
,
w13_qz
,
w13_s
=
fuse_w1_w3_s4
(
w1_qw
,
w1_qz
,
w1_s
,
w3_qw
,
w3_qz
,
w3_s
)
w3_s
)
w13_qw
,
w13_sz
=
convert_s4
(
w13_qw
,
w13_qz
,
w13_s
,
group_size
)
if
w4_weight_layout
==
1
or
w4_weight_layout
==
2
:
w13_qw
=
tp_m_s4
(
w13_qw
,
tp
)
if
w13_qw
.
shape
[
0
]
%
4096
==
0
:
w13_qw_padding
=
torch
.
zeros
(
group_size
*
pad_group_count
,
w13_qw
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
w13_qw
=
torch
.
cat
((
w13_qw
,
w13_qw_padding
),
dim
=
0
).
contiguous
()
w13_qz_padding
=
torch
.
zeros
(
pad_group_count
,
w13_qz
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
w13_qz
=
torch
.
cat
((
w13_qz
,
w13_qz_padding
),
dim
=
0
).
contiguous
()
w13_s_padding
=
torch
.
zeros
(
pad_group_count
,
w13_s
.
shape
[
1
],
dtype
=
torch
.
float16
).
cuda
()
w13_s
=
torch
.
cat
((
w13_s
,
w13_s_padding
),
dim
=
0
).
contiguous
()
w13_qw
,
w13_sz
=
convert_s4_
(
w13_qw
,
w13_qz
,
w13_s
,
group_size
)
w13_qw
=
tp_m_s4
(
w13_qw
,
tp
)
w13_sz
=
permute_trans
(
w13_sz
)
else
:
w13_qw
,
w13_sz
=
convert_s4
(
w13_qw
,
w13_qz
,
w13_s
,
group_size
)
w13_qw
=
tp_m_s4
(
w13_qw
,
tp
)
self
.
save_split
(
w13_qw
,
f
'layers.
{
i
}
.feed_forward.w13.qweight'
,
-
1
)
self
.
save_split
(
w13_qw
,
f
'layers.
{
i
}
.feed_forward.w13.qweight'
,
-
1
)
self
.
save_split
(
w13_sz
,
f
'layers.
{
i
}
.feed_forward.w13.scales_zeros'
,
self
.
save_split
(
w13_sz
,
f
'layers.
{
i
}
.feed_forward.w13.scales_zeros'
,
-
1
)
-
1
)
w2_qw
,
w2_sz
=
convert_s4
(
w2_qw
,
w2_qz
,
w2_s
,
group_size
)
if
w4_weight_layout
==
1
or
w4_weight_layout
==
2
:
#pading
if
w2_qw
.
shape
[
0
]
%
4096
==
0
:
w2_qw_padding
=
torch
.
zeros
(
group_size
*
pad_group_count
,
w2_qw
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
w2_qw
=
torch
.
cat
((
w2_qw
,
w2_qw_padding
),
dim
=
0
).
contiguous
()
w2_qz_padding
=
torch
.
zeros
(
pad_group_count
,
w2_qz
.
shape
[
1
],
dtype
=
torch
.
int32
).
cuda
()
w2_qz
=
torch
.
cat
((
w2_qz
,
w2_qz_padding
),
dim
=
0
).
contiguous
()
w2_s_padding
=
torch
.
zeros
(
pad_group_count
,
w2_s
.
shape
[
1
],
dtype
=
torch
.
float16
).
cuda
()
w2_s
=
torch
.
cat
((
w2_s
,
w2_s_padding
),
dim
=
0
).
contiguous
()
w2_qw
,
w2_sz
=
convert_s4_
(
w2_qw
,
w2_qz
,
w2_s
,
group_size
)
w2_sz
=
permute_trans
(
w2_sz
)
else
:
w2_qw
,
w2_sz
=
convert_s4
(
w2_qw
,
w2_qz
,
w2_s
,
group_size
)
self
.
save_split
(
w2_qw
,
f
'layers.
{
i
}
.feed_forward.w2.qweight'
,
0
)
self
.
save_split
(
w2_qw
,
f
'layers.
{
i
}
.feed_forward.w2.qweight'
,
0
)
self
.
save_split
(
w2_sz
,
f
'layers.
{
i
}
.feed_forward.w2.scales_zeros'
,
0
)
self
.
save_split
(
w2_sz
,
f
'layers.
{
i
}
.feed_forward.w2.scales_zeros'
,
0
)
...
...
lmdeploy/turbomind/turbomind.py
View file @
d26f4c73
...
@@ -147,6 +147,7 @@ class TurboMind:
...
@@ -147,6 +147,7 @@ class TurboMind:
model_name
:
Optional
[
str
]
=
None
,
model_name
:
Optional
[
str
]
=
None
,
model_format
:
Optional
[
str
]
=
None
,
model_format
:
Optional
[
str
]
=
None
,
group_size
:
Optional
[
int
]
=
None
,
group_size
:
Optional
[
int
]
=
None
,
w4_weight_layout
:
Optional
[
int
]
=
None
,
tp
:
Optional
[
int
]
=
None
,
tp
:
Optional
[
int
]
=
None
,
chat_template_config
:
Optional
[
ChatTemplateConfig
]
=
None
,
chat_template_config
:
Optional
[
ChatTemplateConfig
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -179,6 +180,7 @@ class TurboMind:
...
@@ -179,6 +180,7 @@ class TurboMind:
engine_config
=
_update_engine_config
(
engine_config
,
engine_config
=
_update_engine_config
(
engine_config
,
model_format
=
model_format
,
model_format
=
model_format
,
group_size
=
group_size
,
group_size
=
group_size
,
w4_weight_layout
=
w4_weight_layout
,
tp
=
tp
,
tp
=
tp
,
**
kwargs
)
**
kwargs
)
...
@@ -304,6 +306,7 @@ class TurboMind:
...
@@ -304,6 +306,7 @@ class TurboMind:
output_format
=
'w4'
output_format
=
'w4'
data_type
=
'int4'
data_type
=
'int4'
cfg
.
group_size
=
128
cfg
.
group_size
=
128
cfg
.
w4_weight_layout
=
2
else
:
else
:
# output_format = update_output_format(cfg.model_name,
# output_format = update_output_format(cfg.model_name,
# inferred_model_format,
# inferred_model_format,
...
@@ -378,6 +381,7 @@ class TurboMind:
...
@@ -378,6 +381,7 @@ class TurboMind:
self
.
config
=
cfg
self
.
config
=
cfg
self
.
model_name
=
cfg
.
model_name
self
.
model_name
=
cfg
.
model_name
self
.
data_type
=
cfg
.
weight_type
self
.
data_type
=
cfg
.
weight_type
print
(
"from_workspace_cfg:"
,
cfg
)
# create model
# create model
logger
.
warning
(
f
'model_config:
\n\n
{
cfg
.
toini
()
}
'
)
logger
.
warning
(
f
'model_config:
\n\n
{
cfg
.
toini
()
}
'
)
...
...
setup.py
View file @
d26f4c73
...
@@ -69,8 +69,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -69,8 +69,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
file
.
writelines
(
lines
)
file
.
writelines
(
lines
)
file
.
close
()
file
.
close
()
def
copy_ck_so
():
lmdeploy_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
so_path
=
os
.
path
.
join
(
os
.
path
.
join
(
lmdeploy_root
,
"3rdparty"
),
"libgemm_multiB_int4.so"
)
# dtk version
if
os
.
getenv
(
"ROCM_PATH"
):
rocm_path
=
os
.
getenv
(
'ROCM_PATH'
,
""
)
rocm_so_path
=
os
.
path
.
join
(
rocm_path
,
'lib'
)
print
(
"rocm_so_path:"
,
rocm_so_path
)
shutil
.
copy
(
so_path
,
rocm_so_path
)
else
:
shutil
.
copy
(
so_path
,
"usr/local/lib"
)
def
get_version
():
def
get_version
():
copy_ck_so
()
get_version_add
()
get_version_add
()
version_file
=
'lmdeploy/version.py'
version_file
=
'lmdeploy/version.py'
with
open
(
version_file
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
version_file
,
encoding
=
'utf-8'
)
as
f
:
...
...
src/turbomind/kernels/CMakeLists.txt
View file @
d26f4c73
...
@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
...
@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
#set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#
add_subdirectory(gemm_s_f16)
add_subdirectory
(
gemm_s_f16
)
add_subdirectory
(
decoder_multihead_attention
)
add_subdirectory
(
decoder_multihead_attention
)
src/turbomind/kernels/gemm_s_f16/CMakeLists.txt
View file @
d26f4c73
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-fPIC"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-fPIC"
)
add_library
(
gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu
)
add_library
(
gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu
../../models/llama/awq_sugon/gemm_w4_dequation.cu
)
target_compile_options
(
gemm_s4_f16 PRIVATE
target_compile_options
(
gemm_s4_f16 PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr
)
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr
)
set_property
(
TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON
)
set_property
(
TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON
)
...
...
src/turbomind/kernels/gemm_s_f16/common.h
View file @
d26f4c73
...
@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
...
@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0])
// : "=r"(h[0])
// : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
h
[
0
]
=
(
i4s
&
BOTTOM_MASK
)
|
I4s_TO_F16s_MAGIC_NUM
;
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1])
// : "=r"(h[1])
// : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
h
[
1
]
=
(
i4s
&
TOP_MASK
)
|
I4s_TO_F16s_MAGIC_NUM
;
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2])
// : "=r"(h[2])
// : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
h
[
2
]
=
(
top_i4s
&
BOTTOM_MASK
)
|
I4s_TO_F16s_MAGIC_NUM
;
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3])
// : "=r"(h[3])
// : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
printf
(
"=========common.h 86
\n
"
);
h
[
3
]
=
(
top_i4s
&
TOP_MASK
)
|
I4s_TO_F16s_MAGIC_NUM
;
// I use inline PTX below because I am not sure if the compiler will emit
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// performance reliability over code readability.
...
@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
...
@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers.
// Finally, we construct the output numbers.
// Convert elt_01
// Convert elt_01
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
//asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// // Convert elt_23
h
[
0
]
=
h
[
0
]
-
FP16_TOP_MAGIC_NUM
;
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_23
// // Convert elt_45
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h
[
1
]
=
h
[
1
]
*
ONE_SIXTEENTH
+
NEG_64
;
// // Convert elt_67
// Convert elt_45
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
//asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h
[
2
]
=
h
[
2
]
-
FP16_TOP_MAGIC_NUM
;
// Convert elt_67
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h
[
3
]
=
h
[
3
]
*
ONE_SIXTEENTH
+
NEG_64
;
return
result
;
return
result
;
}
}
...
@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
...
@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
// dependency if we issue immediately before required.
const
uint32_t
top_i4s
=
i4s
>>
8
;
const
uint32_t
top_i4s
=
i4s
>>
8
;
printf
(
"=========common.h 133
\n
"
);
// if (0) { // 1024 & 64
// 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
h
[
0
]
=
(
i4s
&
BOT_MASK
)
|
MAGIC_NUM_2
;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
h
[
1
]
=
(
i4s
&
TOP_MASK
)
|
MAGIC_NUM_1
;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
h
[
2
]
=
(
top_i4s
&
BOT_MASK
)
|
MAGIC_NUM_2
;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
h
[
3
]
=
(
top_i4s
&
TOP_MASK
)
|
MAGIC_NUM_1
;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
h
[
0
]
<<=
4
;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
h
[
2
]
<<=
4
;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
// we don't need to subtract the magic nums because zeros will go through the same dequant function
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
// and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
// else { // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// h[0] <<= 4;
// h[2] <<= 4;
// // we don't need to subtract the magic nums because zeros will go through the same dequant function
// // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
return
result
;
return
result
;
}
}
__inline__
__device__
uint32_t
cast_smem_ptr_to_uint
(
void
const
*
const
ptr
)
__inline__
__device__
uint32_t
cast_smem_ptr_to_uint
(
void
const
*
const
ptr
)
{
{
uint32_t
smem_int_ptr
;
uint32_t
smem_int_ptr
;
...
@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
...
@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
__inline__
__device__
half2
apply_Q
(
const
half2
&
x
,
const
half2
&
q
)
__inline__
__device__
half2
apply_Q
(
const
half2
&
x
,
const
half2
&
q
)
{
{
uint
s
,
z
;
//
uint s, z;
(
half2
&
)
z
=
__halves2half2
(
q
.
x
,
q
.
x
);
//
(half2&)z = __halves2half2(q.x, q.x);
(
half2
&
)
s
=
__halves2half2
(
q
.
y
,
q
.
y
);
//
(half2&)s = __halves2half2(q.y, q.y);
auto
&
t
=
(
const
uint
&
)
x
;
//
auto& t = (const uint&)x;
uint
u
,
v
;
uint
v
;
// if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
// }
// }
...
@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
...
@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
// asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
// asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
// asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
// asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
// }
// }
printf
(
"=========common.h 235
\n
"
);
return
(
half2
&
)
v
;
return
(
half2
&
)
v
;
}
}
...
...
src/turbomind/kernels/gemm_s_f16/format.cu
View file @
d26f4c73
// Copyright (c) OpenMMLab. All rights reserved.
// Copyright (c) OpenMMLab. All rights reserved.
#include "common.h"
#include "common.h"
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"
#include <iostream>
#include <iostream>
namespace
turbomind
{
namespace
turbomind
{
...
@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
...
@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
// permutation for [k, m/8] layout
// permutation for [k, m/8] layout
Array
<
int
,
10
>
shape
{
k
/
32
,
2
,
2
,
4
,
2
,
m
/
32
,
2
,
2
,
2
,
4
};
Array
<
int
,
10
>
shape
{
k
/
32
,
2
,
2
,
4
,
2
,
m
/
32
,
2
,
2
,
2
,
4
};
// |warp| lane | 2x2 | a0-7 |
// |warp| lane | 2x2 | a0-7 |
permute_u4
<
0
,
5
,
9
,
8
,
3
,
1
,
6
,
4
,
2
,
7
><<<
512
,
512
,
0
,
st
>>>
(
dst
,
src
,
shape
);
//permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
><<<
512
,
512
,
0
,
st
>>>
(
dst
,
src
,
shape
);
}
void
reformat_s4_k_m8_tarnsw4
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
cudaStream_t
st
)
{
// permutation for [k, m/8] layout
Array
<
int
,
10
>
shape
{
1
,
k
/
8
,
2
,
2
,
2
,
1
,
m
/
8
,
2
,
2
,
2
};
// 0123456-->4,6,7,5,0,3,1,2
//permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4
<
5
,
6
,
8
,
9
,
7
,
0
,
1
,
4
,
2
,
3
><<<
512
,
512
,
0
,
st
>>>
(
dst
,
src
,
shape
);
}
}
__global__
void
dequantize_s4_offset_64
(
uint4
*
dst
,
const
uint32_t
*
src
,
size_t
count
)
__global__
void
dequantize_s4_offset_64
(
uint4
*
dst
,
const
uint32_t
*
src
,
size_t
count
)
...
@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
...
@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
reformat_s4_k_m8
(
A_dst
,
A_src
,
m
,
k
,
st
);
reformat_s4_k_m8
(
A_dst
,
A_src
,
m
,
k
,
st
);
}
}
void
convert_s4_k_m8_
(
uint32_t
*
A_dst
,
half2
*
Q_dst
,
half
*
workspace
,
const
uint32_t
*
A_src
,
const
half
*
scales
,
const
uint32_t
*
qzeros
,
int
m
,
int
k
,
int
group_size
,
cudaStream_t
st
)
{
dequantize_s4_offset_64
<<<
256
,
256
,
0
,
st
>>>
((
uint4
*
)
workspace
,
qzeros
,
k
/
group_size
*
m
/
8
);
merge_Q
<<<
256
,
256
,
0
,
st
>>>
(
Q_dst
,
scales
,
workspace
,
k
/
group_size
*
m
);
reformat_s4_k_m8_tarnsw4
(
A_dst
,
A_src
,
m
,
k
,
st
);
}
void
transpose_qk_s4_k_m8_hf
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
int
size_per_head
,
cudaStream_t
st
)
void
transpose_qk_s4_k_m8_hf
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
int
size_per_head
,
cudaStream_t
st
)
{
{
Array
<
int
,
7
>
shape
{
k
,
m
/
size_per_head
,
2
,
size_per_head
/
2
/
8
,
2
,
2
,
2
};
Array
<
int
,
7
>
shape
{
k
,
m
/
size_per_head
,
2
,
size_per_head
/
2
/
8
,
2
,
2
,
2
};
...
@@ -140,5 +167,82 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
...
@@ -140,5 +167,82 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
{
{
dequantize_s4_kernel
<<<
512
,
512
>>>
(
dst
,
src
,
count
);
dequantize_s4_kernel
<<<
512
,
512
>>>
(
dst
,
src
,
count
);
}
}
// Row-major in-place dequantization over a k x n buffer laid out row-major:
// element (i, j) shares the (zero, scale) pair of its row group i/group_size
// and column j. `weight` holds already-unpacked half values on entry and the
// dequantized result on exit. One thread per element; k is unused here.
__global__ void
dequant_kernel(int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_kernels) {
        return;
    }
    const int col = tid % n;
    const int row = tid / n;
    // Single lookup for the pair shared by this element's quantization group.
    const int  pair_idx = row / group_size * n + col;
    const half zero     = zeros_and_scales[pair_idx].data[0];
    const half scale    = zeros_and_scales[pair_idx].data[1];
    const float dequantized = (weight[tid] - zero) * scale;
    weight[tid] = __float2half(dequantized);
}
// Column-major variant of dequant_kernel: each run of `group_size`
// consecutive elements in `weight` shares one (zero, scale) pair from
// zeros_and_scales. In-place: weight[i] <- (weight[i] - zero) * scale.
// One thread per element; k and n are unused here.
__global__ void
dequant_kernel_colmajor(int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_kernels) {
        return;
    }
    const int  group = tid / group_size;
    const half zero  = zeros_and_scales[group].data[0];
    const half scale = zeros_and_scales[group].data[1];
    const float dequantized = (weight[tid] - zero) * scale;
    weight[tid] = __float2half(dequantized);
}
// Fully dequantizes a packed 4-bit [k, n] weight into half-precision
// `output` (row-major indexing in the second pass):
//   1. dequantize_s4_offset_64 unpacks 8 nibbles per uint32 into halves,
//   2. dequant_kernel applies the per-group (zero, scale) pairs in place.
// Both kernels are enqueued on `stream`; no synchronization is performed.
void dequant_w4_gemm(cudaStream_t    stream,
                     half*           output,
                     const uint32_t* weight,
                     const half2*    zeros_and_scales,
                     int             k,
                     int             n,
                     int             group_size)
{
    // k*n nibbles packed 8 per word.
    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
    int num_kernels = k * n;
    dequant_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, zeros_and_scales, k, n, group_size);
}
// Column-major counterpart of dequant_w4_gemm: unpacks the packed 4-bit
// weight into `output`, then applies per-group (zero, scale) pairs with
// dequant_kernel_colmajor, where each group_size-long run of consecutive
// elements shares one pair. Enqueued on `stream`; no synchronization here.
void dequant_w4_gemm_colmajor(cudaStream_t    stream,
                              half*           output,
                              const uint32_t* weight,
                              const half2*    zeros_and_scales,
                              int             k,
                              int             n,
                              int             group_size)
{
    // k*n nibbles packed 8 per word.
    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
    int num_kernels = k * n;
    dequant_kernel_colmajor<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, zeros_and_scales, k, n, group_size);
}
// Fused SiLU-gating kernel: reads interleaved (x, y) half pairs from `src`
// (reinterpreted as half2) and writes output[id] = silu(x) * y, with
// silu(x) = x / (1 + exp(-x)) computed in float via fast __expf.
// num_kernels is the number of PAIRS, i.e. the number of outputs, so the
// output occupies half the element count of `src`. m and n are unused.
__global__ void FusedSiluActivation_kernel(int num_kernels, half* output, const uint32_t* src, int m, int n)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    // `src` aliases a half2 buffer; the uint32_t* signature matches the
    // packed-word convention used by the surrounding w4 kernels.
    auto  data = ((half2*)src)[id];
    float x    = __half2float(data.data[0]);
    float y    = __half2float(data.data[1]);
    float silu = x / (1.f + __expf(-x)) * y;
    output[id] = __float2half(silu);
}
// Plain element-wise copy: output[i] = src[i] for the first num_kernels
// elements, one thread per element. m and n are unused.
__global__ void assign_kernel(int num_kernels, half* output, const half* src, int m, int n)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < num_kernels) {
        output[idx] = src[idx];
    }
}
// Post-GEMM activation dispatch over an m x n half buffer:
//   type == 0: element-wise copy src -> output,
//   type == 1: fused SiLU gating — src is read as interleaved (x, y) half2
//              pairs and output[i] = silu(x) * y, producing m*n/2 elements,
//   otherwise: no-op.
// Kernels are enqueued on `stream`; no synchronization is performed here.
void addFusedSiluActivation(cudaStream_t stream, half* output, const half* src, int m, int n, int type)
{
    const int num_kernels = m * n;
    switch (type) {
        case 0:
            assign_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
                num_kernels, output, src, m, n);
            break;
        case 1: {
            // Each thread consumes one (x, y) pair, so only num_kernels / 2
            // threads are needed. Sizing the grid from the pair count (instead
            // of num_kernels) avoids launching twice the required blocks whose
            // extra threads would exit immediately on the bounds guard.
            const int pair_count = num_kernels / 2;
            FusedSiluActivation_kernel<<<(pair_count + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
                pair_count, output, (const uint32_t*)src, m, n);
            break;
        }
        default:
            return;
    }
}
}
// namespace turbomind
}
// namespace turbomind
src/turbomind/kernels/gemm_s_f16/format.h
View file @
d26f4c73
...
@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
...
@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
int
group_size
,
int
group_size
,
cudaStream_t
st
=
{});
cudaStream_t
st
=
{});
void
convert_s4_k_m8_
(
uint32_t
*
A_dst
,
half2
*
Q_dst
,
half
*
workspace
,
const
uint32_t
*
A_src
,
const
half
*
scales
,
const
uint32_t
*
qzeros
,
int
m
,
int
k
,
int
group_size
,
cudaStream_t
st
=
{});
void
transpose_qk_s4_k_m8_hf
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
int
size_per_head
,
cudaStream_t
st
=
{});
void
transpose_qk_s4_k_m8_hf
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
int
size_per_head
,
cudaStream_t
st
=
{});
void
fuse_w1_w3_s4_k_m8
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
cudaStream_t
st
=
{});
void
fuse_w1_w3_s4_k_m8
(
uint32_t
*
dst
,
const
uint32_t
*
src
,
int
m
,
int
k
,
cudaStream_t
st
=
{});
...
...
src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h
View file @
d26f4c73
...
@@ -13,6 +13,10 @@
...
@@ -13,6 +13,10 @@
namespace
turbomind
{
namespace
turbomind
{
extern
bool
g_dump_kernel_info_once
;
extern
bool
g_dump_kernel_info_once
;
void
dequant_w4_gemm
(
cudaStream_t
stream
,
half
*
output
,
const
uint32_t
*
weight
,
const
half2
*
zeros_and_scales
,
int
k
,
int
n
,
int
group_size
);
void
addFusedSiluActivation
(
cudaStream_t
stream
,
half
*
output
,
const
half
*
src
,
int
m
,
int
n
,
int
type
);
void
dequant_w4_gemm_colmajor
(
cudaStream_t
stream
,
half
*
output
,
const
uint32_t
*
weight
,
const
half2
*
zeros_and_scales
,
int
k
,
int
n
,
int
group_size
);
class
GemmS4F16
{
class
GemmS4F16
{
public:
public:
...
...
src/turbomind/kernels/gemm_s_f16/gemm_template.h
View file @
d26f4c73
This diff is collapsed.
Click to expand it.
src/turbomind/models/llama/BlockManager.cc
View file @
d26f4c73
...
@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
...
@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
return
false
;
return
false
;
}
}
//auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
auto
ptr
=
(
uint8_t
*
)
allocator_
->
malloc
(
block_size_
*
chunk_size
);
auto
ptr
=
(
uint8_t
*
)
allocator_
->
malloc
(
block_size_
*
chunk_size
);
if
(
!
ptr
)
{
if
(
!
ptr
)
{
return
false
;
return
false
;
...
...
src/turbomind/models/llama/CMakeLists.txt
View file @
d26f4c73
...
@@ -19,13 +19,14 @@ add_library(Llama STATIC
...
@@ -19,13 +19,14 @@ add_library(Llama STATIC
unified_attention_layer.cc
unified_attention_layer.cc
llama_kernels.cu
llama_kernels.cu
llama_decoder_kernels.cu
llama_decoder_kernels.cu
llama_utils.cu
)
llama_utils.cu
./awq_sugon/gemm_w4_dequation.cu
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-fPIC"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-fPIC"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-fPIC"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-fPIC"
)
#set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries
(
Llama PUBLIC cudart
target_link_libraries
(
Llama PUBLIC cudart
#
gemm_s4_f16
gemm_s4_f16
cublasMMWrapper
cublasMMWrapper
DynamicDecodeLayer
DynamicDecodeLayer
activation_kernels
activation_kernels
...
@@ -41,7 +42,8 @@ target_link_libraries(Llama PUBLIC cudart
...
@@ -41,7 +42,8 @@ target_link_libraries(Llama PUBLIC cudart
memory_utils
memory_utils
nccl_utils
nccl_utils
cuda_utils
cuda_utils
logger
)
logger
gemm_multiB_int4
)
# llama_fmha)
# llama_fmha)
if
(
NOT MSVC
)
if
(
NOT MSVC
)
...
...
src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
View file @
d26f4c73
...
@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
...
@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t
inter_size
,
size_t
inter_size
,
WeightType
weight_type
,
WeightType
weight_type
,
int
group_size
,
int
group_size
,
int
w4_weight_layout
,
bool
attn_bias
,
bool
attn_bias
,
size_t
tensor_para_size
,
size_t
tensor_para_size
,
size_t
tensor_para_rank
)
:
size_t
tensor_para_rank
)
:
...
@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
...
@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
self_attn_weights
.
qkv
.
output_dims
=
(
head_num
+
2
*
kv_head_num
)
*
size_per_head
/
tensor_para_size_
;
self_attn_weights
.
qkv
.
output_dims
=
(
head_num
+
2
*
kv_head_num
)
*
size_per_head
/
tensor_para_size_
;
self_attn_weights
.
qkv
.
type
=
weight_type
;
self_attn_weights
.
qkv
.
type
=
weight_type
;
self_attn_weights
.
qkv
.
group_size
=
group_size
;
self_attn_weights
.
qkv
.
group_size
=
group_size
;
self_attn_weights
.
qkv
.
w4_weight_layout
=
w4_weight_layout
;
self_attn_weights
.
output
.
input_dims
=
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
output
.
input_dims
=
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
output
.
output_dims
=
hidden_units_
;
self_attn_weights
.
output
.
output_dims
=
hidden_units_
;
self_attn_weights
.
output
.
type
=
weight_type
;
self_attn_weights
.
output
.
type
=
weight_type
;
self_attn_weights
.
output
.
group_size
=
group_size
;
self_attn_weights
.
output
.
group_size
=
group_size
;
self_attn_weights
.
output
.
w4_weight_layout
=
w4_weight_layout
;
ffn_weights
.
gating
.
input_dims
=
hidden_units_
;
ffn_weights
.
gating
.
input_dims
=
hidden_units_
;
ffn_weights
.
gating
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
gating
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
gating
.
type
=
weight_type
;
ffn_weights
.
gating
.
type
=
weight_type
;
ffn_weights
.
gating
.
group_size
=
group_size
;
ffn_weights
.
gating
.
group_size
=
group_size
;
ffn_weights
.
gating
.
w4_weight_layout
=
w4_weight_layout
;
ffn_weights
.
intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
intermediate
.
type
=
weight_type
;
ffn_weights
.
intermediate
.
type
=
weight_type
;
ffn_weights
.
intermediate
.
group_size
=
group_size
;
ffn_weights
.
intermediate
.
group_size
=
group_size
;
ffn_weights
.
intermediate
.
w4_weight_layout
=
w4_weight_layout
;
ffn_weights
.
fused_gating_intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
fused_gating_intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
fused_gating_intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
*
2
;
ffn_weights
.
fused_gating_intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
*
2
;
ffn_weights
.
fused_gating_intermediate
.
type
=
weight_type
;
ffn_weights
.
fused_gating_intermediate
.
type
=
weight_type
;
ffn_weights
.
fused_gating_intermediate
.
group_size
=
group_size
;
ffn_weights
.
fused_gating_intermediate
.
group_size
=
group_size
;
ffn_weights
.
fused_gating_intermediate
.
w4_weight_layout
=
w4_weight_layout
;
ffn_weights
.
output
.
input_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
output
.
input_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
output
.
output_dims
=
hidden_units_
;
ffn_weights
.
output
.
output_dims
=
hidden_units_
;
ffn_weights
.
output
.
type
=
weight_type
;
ffn_weights
.
output
.
type
=
weight_type
;
ffn_weights
.
output
.
group_size
=
group_size
;
ffn_weights
.
output
.
group_size
=
group_size
;
ffn_weights
.
output
.
w4_weight_layout
=
w4_weight_layout
;
mallocWeights
();
mallocWeights
();
}
}
...
@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
...
@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
else
{
// int8, int4
else
{
// int8, int4
const
int
factor
=
sizeof
(
float
)
*
8
/
bit_size
;
const
int
factor
=
sizeof
(
float
)
*
8
/
bit_size
;
FT_CHECK
(
weights
.
input_dims
%
factor
==
0
);
FT_CHECK
(
weights
.
input_dims
%
factor
==
0
);
deviceMalloc
((
int
**
)
&
weights
.
kernel
,
weights
.
input_dims
*
weights
.
output_dims
/
factor
);
// //读环境变量
deviceMemSetZero
((
int
*
)
weights
.
kernel
,
weights
.
input_dims
*
weights
.
output_dims
/
factor
);
// int m_weightlayout_switch=1;
// interleaved scales/zeros
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
deviceMalloc
((
T
**
)
&
weights
.
scales_and_zeros
,
weights
.
input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
);
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if
((
weights
.
input_dims
%
4096
==
0
)
&&
(
weights
.
w4_weight_layout
==
1
||
weights
.
w4_weight_layout
==
2
))
{
size_t
new_input_dims
=
weights
.
input_dims
+
2
*
weights
.
group_size
;
deviceMalloc
((
int
**
)
&
weights
.
kernel
,
new_input_dims
*
weights
.
output_dims
/
factor
);
deviceMemSetZero
((
int
*
)
weights
.
kernel
,
new_input_dims
*
weights
.
output_dims
/
factor
);
// interleaved scales/zeros
deviceMalloc
((
T
**
)
&
weights
.
scales_and_zeros
,
new_input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
);
}
else
{
deviceMalloc
((
int
**
)
&
weights
.
kernel
,
weights
.
input_dims
*
weights
.
output_dims
/
factor
);
deviceMemSetZero
((
int
*
)
weights
.
kernel
,
weights
.
input_dims
*
weights
.
output_dims
/
factor
);
// interleaved scales/zeros
deviceMalloc
((
T
**
)
&
weights
.
scales_and_zeros
,
weights
.
input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
);
}
}
}
}
}
...
@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
...
@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
}
}
else
{
// int8, int4
else
{
// int8, int4
const
int
factor
=
sizeof
(
float
)
*
8
/
bit_size
;
const
int
factor
=
sizeof
(
float
)
*
8
/
bit_size
;
output
.
insert
(
get_name
(
"qweight"
),
// //读环境变量
Tensor
{
MEMORY_GPU
,
// int m_weightlayout_switch=1;
TYPE_INT32
,
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
{
weights
.
input_dims
*
weights
.
output_dims
*
sizeof
(
int
)
/
factor
},
// if (env_weightlayout_str != nullptr) {
weights
.
kernel
});
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
output
.
insert
(
get_name
(
"scales_zeros"
),
// }
Tensor
{
MEMORY_GPU
,
if
((
weights
.
input_dims
%
4096
==
0
)
&&
(
weights
.
w4_weight_layout
==
1
||
weights
.
w4_weight_layout
==
2
))
getTensorType
<
T
>
(),
{
{
weights
.
input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
*
sizeof
(
T
)},
size_t
new_input_dims
=
weights
.
input_dims
+
weights
.
group_size
;
weights
.
scales_and_zeros
});
output
.
insert
(
get_name
(
"qweight"
),
Tensor
{
MEMORY_GPU
,
TYPE_INT32
,
{
new_input_dims
*
weights
.
output_dims
*
sizeof
(
int
)
/
factor
},
weights
.
kernel
});
output
.
insert
(
get_name
(
"scales_zeros"
),
Tensor
{
MEMORY_GPU
,
getTensorType
<
T
>
(),
{
new_input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
*
sizeof
(
T
)},
weights
.
scales_and_zeros
});
}
else
{
output
.
insert
(
get_name
(
"qweight"
),
Tensor
{
MEMORY_GPU
,
TYPE_INT32
,
{
weights
.
input_dims
*
weights
.
output_dims
*
sizeof
(
int
)
/
factor
},
weights
.
kernel
});
output
.
insert
(
get_name
(
"scales_zeros"
),
Tensor
{
MEMORY_GPU
,
getTensorType
<
T
>
(),
{
weights
.
input_dims
/
weights
.
group_size
*
weights
.
output_dims
*
2
*
sizeof
(
T
)},
weights
.
scales_and_zeros
});
}
}
}
}
}
...
@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w,
...
@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w,
FT_CHECK
(
dim1
%
factor
==
0
);
FT_CHECK
(
dim1
%
factor
==
0
);
std
::
vector
<
size_t
>
w_shape
{
dim0
,
dim1
/
factor
*
sizeof
(
uint32_t
)};
// //读环境变量
loadWeightFromBin
((
int8_t
*
)
w
.
kernel
,
w_shape
,
prefix
+
".qweight"
,
FtCudaDataType
::
INT8
,
{});
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if
((
dim0
%
4096
==
0
)
&&
(
w
.
w4_weight_layout
==
1
||
w
.
w4_weight_layout
==
2
))
{
size_t
new_dim0
=
dim0
+
2
*
w
.
group_size
;
std
::
vector
<
size_t
>
w_shape
{
new_dim0
,
dim1
/
factor
*
sizeof
(
uint32_t
)};
loadWeightFromBin
((
int8_t
*
)
w
.
kernel
,
w_shape
,
prefix
+
".qweight"
,
FtCudaDataType
::
INT8
,
{});
const
size_t
group_count
=
w
.
group_size
>
0
?
new_dim0
/
w
.
group_size
:
1
;
const
size_t
group_count
=
w
.
group_size
>
0
?
dim0
/
w
.
group_size
:
1
;
loadWeightFromBin
((
half
*
)
w
.
scales_and_zeros
,
{
group_count
,
dim1
*
2
},
prefix
+
".scales_zeros"
,
type
,
{});
}
else
{
std
::
vector
<
size_t
>
w_shape
{
dim0
,
dim1
/
factor
*
sizeof
(
uint32_t
)};
loadWeightFromBin
((
int8_t
*
)
w
.
kernel
,
w_shape
,
prefix
+
".qweight"
,
FtCudaDataType
::
INT8
,
{});
const
size_t
group_count
=
w
.
group_size
>
0
?
dim0
/
w
.
group_size
:
1
;
loadWeightFromBin
((
half
*
)
w
.
scales_and_zeros
,
{
group_count
,
dim1
*
2
},
prefix
+
".scales_zeros"
,
type
,
{});
loadWeightFromBin
((
half
*
)
w
.
scales_and_zeros
,
{
group_count
,
dim1
*
2
},
prefix
+
".scales_zeros"
,
type
,
{});
}
}
}
}
}
...
...
src/turbomind/models/llama/LlamaDecoderLayerWeight.h
View file @
d26f4c73
...
@@ -35,6 +35,7 @@ public:
...
@@ -35,6 +35,7 @@ public:
size_t
inter_size
,
size_t
inter_size
,
WeightType
weight_type
,
WeightType
weight_type
,
int
group_size
,
int
group_size
,
int
w4_weight_layout
,
bool
attn_bias
,
bool
attn_bias
,
size_t
tensor_para_size
,
size_t
tensor_para_size
,
size_t
tensor_para_rank
);
size_t
tensor_para_rank
);
...
...
src/turbomind/models/llama/LlamaDenseWeight.h
View file @
d26f4c73
...
@@ -63,6 +63,7 @@ struct LlamaDenseWeight {
...
@@ -63,6 +63,7 @@ struct LlamaDenseWeight {
T
*
bias
;
T
*
bias
;
T
*
scales_and_zeros
;
T
*
scales_and_zeros
;
int
group_size
;
int
group_size
;
int
w4_weight_layout
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
src/turbomind/models/llama/LlamaFfnLayer.cc
View file @
d26f4c73
...
@@ -29,7 +29,7 @@ namespace turbomind {
...
@@ -29,7 +29,7 @@ namespace turbomind {
template
<
typename
T
>
template
<
typename
T
>
void
LlamaFfnLayer
<
T
>::
allocateBuffer
(
size_t
token_num
)
void
LlamaFfnLayer
<
T
>::
allocateBuffer
(
size_t
token_num
)
{
{
inter_buf_
=
(
T
*
)
allocator_
->
reMalloc
(
inter_buf_
,
sizeof
(
T
)
*
token_num
*
inter_size_
,
false
);
inter_buf_
=
(
T
*
)
allocator_
->
reMalloc
(
inter_buf_
,
2
*
sizeof
(
T
)
*
token_num
*
inter_size_
,
false
);
gating_buf_
=
(
T
*
)
allocator_
->
reMalloc
(
gating_buf_
,
sizeof
(
T
)
*
token_num
*
inter_size_
,
false
);
gating_buf_
=
(
T
*
)
allocator_
->
reMalloc
(
gating_buf_
,
sizeof
(
T
)
*
token_num
*
inter_size_
,
false
);
is_allocate_buffer_
=
true
;
is_allocate_buffer_
=
true
;
}
}
...
@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
...
@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
if
(
weights
->
fused_gating_intermediate
.
kernel
)
{
if
(
weights
->
fused_gating_intermediate
.
kernel
)
{
NvtxScope
scope
(
"fused_silu_ffn"
);
NvtxScope
scope
(
"fused_silu_ffn"
);
linear_
.
forward
(
// linear_.forward(
gating_buf_
,
ffn_input_data
,
num_token
,
weights
->
fused_gating_intermediate
,
LlamaLinear
<
T
>::
kFusedSiluFfn
);
// gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
linear_
.
forward_ffn
(
gating_buf_
,
inter_buf_
,
ffn_input_data
,
num_token
,
weights
->
fused_gating_intermediate
,
LlamaLinear
<
T
>::
kFusedSiluFfn
);
}
}
else
{
else
{
{
// w1(x)
{
// w1(x)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment