OpenDAS / Lmdeploy · Commits · 2700abb3

Unverified commit 2700abb3, authored Jun 24, 2023 by Li Zhang, committed via GitHub on Jun 24, 2023
Parent: ee962784

Support attention bias (#14)

* support attention bias
* fix conflict
Showing 13 changed files with 177 additions and 122 deletions (+177 -122).
llmdeploy/serve/fastertransformer/deploy.py                           +40 -21
src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc       +6  -5
src/fastertransformer/models/llama/LlamaContextDecoder.cc             +18 -15
src/fastertransformer/models/llama/LlamaDecoder.cc                    +18 -15
src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc         +21 -16
src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h          +15  -9
src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc   +4  -4
src/fastertransformer/models/llama/LlamaWeight.cc                      +4  -2
src/fastertransformer/models/llama/LlamaWeight.h                       +3  -1
src/fastertransformer/models/llama/llama_decoder_kernels.cu           +31 -22
src/fastertransformer/models/llama/llama_decoder_kernels.h             +2  -2
src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc        +12  -9
src/fastertransformer/triton_backend/llama/LlamaTritonModel.h          +3  -1

llmdeploy/serve/fastertransformer/deploy.py

@@ -81,6 +81,8 @@ def export(model_name: str,
         param = param.half()
         param.contiguous().numpy().tofile(osp.join(out_dir, name))
 
+    attn_bias = False
+
     # reverse the splitting axes since the weights are transposed above
     for param_name, param_data in model_params.items():
         if param_name == 'tok_embeddings.weight':
@@ -88,13 +90,18 @@ def export(model_name: str,
             head_num = dim // size_per_head
             split_dim = None
             key, ext = param_name.split('.')[-2:]
+            if key == 'w_qkv' and ext == 'bias':
+                attn_bias = True
             copy = False
             if key in ['w1', 'w3', 'w_qkv']:
-                split_dim = -1
-                if key == 'w1':
-                    inter_size = param_data.shape[-1]
+                if ext in ['bias']:
+                    copy = True
+                else:
+                    split_dim = -1
+                if key == 'w1':
+                    inter_size = param_data.shape[-1]
             elif key in ['w2', 'wo']:
-                if ext in ['scales', 'zeros']:
+                if ext in ['scales', 'zeros', 'bias']:
                     copy = True
                 else:
                     split_dim = 0
@@ -129,6 +136,7 @@ def export(model_name: str,
         rotary_embedding=size_per_head,
         inter_size=inter_size,
         norm_eps=norm_eps,
+        attn_bias=attn_bias,
        start_id=bos_id,
        end_id=eos_id,
        weight_type='fp16',
@@ -189,20 +197,28 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
     for i, ckpt_path in enumerate(checkpoints):
         ckpt = torch.load(ckpt_path, map_location='cpu')
         for param_name, param_data in ckpt.items():
-            key = param_name.split('.')[-2]
+            key, ext = param_name.split('.')[-2:]
             # column-parallel
             if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']:
                 size = param_data.size(0)
-                param = get_param(
-                    param_name,
-                    [size * n_ckpt, param_data.size(1)])
-                param.data[size * i:size * (i + 1), :] = param_data
+                if ext == 'weight':
+                    param = get_param(
+                        param_name,
+                        [size * n_ckpt, param_data.size(1)])
+                    param.data[size * i:size * (i + 1), :] = param_data
+                else:  # bias
+                    param = get_param(param_name, [size * n_ckpt])
+                    param.data[size * i:size * (i + 1)] = param_data
             # row-parallel
             elif key in ['w2', 'wo', 'tok_embeddings']:
                 size = param_data.size(-1)
-                param = get_param(param_name,
-                                  [param_data.size(0), size * n_ckpt])
-                param.data[:, size * i:size * (i + 1)] = param_data
+                if ext == 'weight':
+                    param = get_param(param_name,
+                                      [param_data.size(0), size * n_ckpt])
+                    param.data[:, size * i:size * (i + 1)] = param_data
+                else:  # bias
+                    param = get_param(param_name, [size])
+                    param.data = param_data
             elif i == 0:
                 param = get_param(param_name, param_data.size())
                 param.data = param_data
@@ -216,15 +232,18 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
         param.data = param.data.t()
 
     # concat qkv projection
-    for i in range(1000):
-        _qkv = [f'layers.{i}.attention.{k}.weight' for k in ['wq', 'wk', 'wv']]
-        try:
-            qkv = tuple(map(model_params.pop, _qkv))
-        except KeyError:
-            break
-        qkv = torch.stack(qkv, dim=1)
-        model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv
-        print(qkv.shape, qkv.dtype)
+    for t in ['weight', 'bias']:
+        for i in range(1000):
+            _qkv = [f'layers.{i}.attention.{k}.{t}' for k in [
+                'wq', 'wk', 'wv']]
+            try:
+                qkv = tuple(map(model_params.pop, _qkv))
+            except KeyError:
+                break
+            # concat by output_dims
+            qkv = torch.stack(qkv, dim=qkv[0].dim() - 1)
+            print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
+            model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv
 
     assert num_layer == i, f'miss matched layers: {num_layer} vs {i}'
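
For reference, a minimal standalone sketch (not the repository's code) of the qkv-concat logic introduced above: the per-layer wq/wk/wv tensors are popped from the state dict and stacked along their last dimension, so the same loop handles both 2-D weights and the newly supported 1-D biases. The helper name `concat_qkv` and the toy sizes below are illustrative only.

import torch

def concat_qkv(model_params: dict, num_layer: int) -> dict:
    for t in ['weight', 'bias']:
        for i in range(num_layer):
            keys = [f'layers.{i}.attention.{k}.{t}' for k in ['wq', 'wk', 'wv']]
            try:
                qkv = tuple(map(model_params.pop, keys))
            except KeyError:
                break  # checkpoints without attention bias simply skip the 'bias' pass
            # stack along the output dim: dim()-1 is 1 for (transposed) 2-D weights
            # and 0 for 1-D biases
            model_params[f'layers.{i}.attention.w_qkv.{t}'] = torch.stack(
                qkv, dim=qkv[0].dim() - 1)
    return model_params

# toy usage: one layer, hidden size 8, bias-only entries
params = {f'layers.0.attention.{k}.bias': torch.zeros(8) for k in ['wq', 'wk', 'wv']}
merged = concat_qkv(params, num_layer=1)
print(merged['layers.0.attention.w_qkv.bias'].shape)  # torch.Size([3, 8])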

src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc

@@ -15,8 +15,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc
 
 #include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h"
 #include "src/fastertransformer/kernels/bert_preprocess_kernels.h"
@@ -157,9 +158,9 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                                    v_buf_2_,
                                                    PrefixPromptBatchWeightsParam<T>{},
                                                    qkv_buf_,
-                                                   (const T*)nullptr,  // qkv_bias
+                                                   weights->qkv.bias,
                                                    padding_offset,  // padding_offset,
                                                    history_length,  // used for applying rotary embedding
                                                    batch_size,
                                                    max_q_len,   // seq_len
                                                    num_token,   // batch_size * seq_len

src/fastertransformer/models/llama/LlamaContextDecoder.cc

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
 
 #include "src/fastertransformer/models/llama/LlamaContextDecoder.h"
 #include "src/fastertransformer/kernels/bert_preprocess_kernels.h"
@@ -243,13 +244,14 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
         /// self-attention
         forwardSelfAttn(sess, input_tensors, layer, false);
 
-        invokeFusedAddResidualRMSNorm(decoder_input_output,
-                                      attn_ffn_io_,
-                                      decoder_layer_weights->at(layer)->ffn_norm_weights,
-                                      rmsnorm_eps_,
-                                      sess.token_num,
-                                      hidden_units_,
-                                      stream_);
+        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
+                                          attn_ffn_io_,
+                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
+                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
+                                          rmsnorm_eps_,
+                                          sess.token_num,
+                                          hidden_units_,
+                                          stream_);
         sync_check_cuda_error();
 
         ////////////////////////////////////////////
@@ -260,13 +262,14 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
-        invokeFusedAddResidualRMSNorm(decoder_input_output,  //
-                                      attn_ffn_io_,
-                                      scale_weight,
-                                      rmsnorm_eps_,
-                                      sess.token_num,
-                                      hidden_units_,
-                                      stream_);
+        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
+                                          attn_ffn_io_,
+                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
+                                          scale_weight,
+                                          rmsnorm_eps_,
+                                          sess.token_num,
+                                          hidden_units_,
+                                          stream_);
         sync_check_cuda_error();
     }

src/fastertransformer/models/llama/LlamaDecoder.cc

@@ -16,7 +16,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
 
 #include "src/fastertransformer/models/llama/LlamaDecoder.h"
 #include "src/fastertransformer/models/llama/llama_decoder_kernels.h"
@@ -205,13 +206,14 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
         // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
         forwardSelfAttn(sess, decoder_output, input_tensors, layer);
 
-        invokeFusedAddResidualRMSNorm(decoder_input,
-                                      decoder_output,
-                                      decoder_layer_weights->at(layer)->ffn_norm_weights,
-                                      rmsnorm_eps_,
-                                      sess.batch_size,
-                                      hidden_units_,
-                                      stream_);
+        invokeFusedAddBiasResidualRMSNorm(decoder_input,
+                                          decoder_output,
+                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
+                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
+                                          rmsnorm_eps_,
+                                          sess.batch_size,
+                                          hidden_units_,
+                                          stream_);
         sync_check_cuda_error();
 
         // decoder_layer_output_ = ffn(decoder_normed_input_)
@@ -219,13 +221,14 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
-        invokeFusedAddResidualRMSNorm(decoder_input,  //
-                                      decoder_output,
-                                      scale_weight,
-                                      rmsnorm_eps_,
-                                      sess.batch_size,
-                                      hidden_units_,
-                                      stream_);
+        invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
+                                          decoder_output,
+                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
+                                          scale_weight,
+                                          rmsnorm_eps_,
+                                          sess.batch_size,
+                                          hidden_units_,
+                                          stream_);
         sync_check_cuda_error();
     }
...
src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc
View file @
2700abb3
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
* limitations under the License.
* limitations under the License.
*/
*/
// Modified from
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h"
#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h"
#include "src/fastertransformer/utils/logger.h"
#include "src/fastertransformer/utils/logger.h"
...
@@ -25,33 +25,38 @@
...
@@ -25,33 +25,38 @@
namespace
fastertransformer
{
namespace
fastertransformer
{
template
<
typename
T
>
template
<
typename
T
>
LlamaDecoderLayerWeight
<
T
>::
LlamaDecoderLayerWeight
(
LlamaDecoderLayerWeight
<
T
>::
LlamaDecoderLayerWeight
(
size_t
hidden_units
,
size_t
hidden_units
,
size_t
inter_size
,
WeightType
weight_type
,
size_t
tensor_para_size
,
size_t
tensor_para_rank
)
:
size_t
inter_size
,
WeightType
weight_type
,
bool
attn_bias
,
size_t
tensor_para_size
,
size_t
tensor_para_rank
)
:
hidden_units_
(
hidden_units
),
hidden_units_
(
hidden_units
),
inter_size_
(
inter_size
),
inter_size_
(
inter_size
),
weight_type_
(
weight_type
),
weight_type_
(
weight_type
),
attn_bias_
(
attn_bias
),
tensor_para_size_
(
tensor_para_size
),
tensor_para_size_
(
tensor_para_size
),
tensor_para_rank_
(
tensor_para_rank
)
tensor_para_rank_
(
tensor_para_rank
)
{
{
self_attn_weights
.
qkv
.
input_dims
=
hidden_units_
;
self_attn_weights
.
qkv
.
input_dims
=
hidden_units_
;
self_attn_weights
.
qkv
.
output_dims
=
3
*
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
qkv
.
output_dims
=
3
*
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
qkv
.
type
=
weight_type
;
self_attn_weights
.
qkv
.
type
=
weight_type
;
self_attn_weights
.
output
.
input_dims
=
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
output
.
input_dims
=
hidden_units_
/
tensor_para_size_
;
self_attn_weights
.
output
.
output_dims
=
hidden_units_
;
self_attn_weights
.
output
.
output_dims
=
hidden_units_
;
self_attn_weights
.
output
.
type
=
weight_type
;
self_attn_weights
.
output
.
type
=
weight_type
;
ffn_weights
.
gating
.
input_dims
=
hidden_units_
;
ffn_weights
.
gating
.
input_dims
=
hidden_units_
;
ffn_weights
.
gating
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
gating
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
gating
.
type
=
weight_type
;
ffn_weights
.
gating
.
type
=
weight_type
;
ffn_weights
.
intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
intermediate
.
input_dims
=
hidden_units_
;
ffn_weights
.
intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
intermediate
.
output_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
intermediate
.
type
=
weight_type
;
ffn_weights
.
intermediate
.
type
=
weight_type
;
ffn_weights
.
output
.
input_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
output
.
input_dims
=
inter_size_
/
tensor_para_size_
;
ffn_weights
.
output
.
output_dims
=
hidden_units_
;
ffn_weights
.
output
.
output_dims
=
hidden_units_
;
ffn_weights
.
output
.
type
=
weight_type
;
ffn_weights
.
output
.
type
=
weight_type
;
mallocWeights
();
mallocWeights
();
}
}
...
@@ -117,8 +122,8 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
...
@@ -117,8 +122,8 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
deviceMalloc
((
T
**
)
&
self_attn_norm_weights
,
hidden_units_
);
deviceMalloc
((
T
**
)
&
self_attn_norm_weights
,
hidden_units_
);
deviceMalloc
((
T
**
)
&
ffn_norm_weights
,
hidden_units_
);
deviceMalloc
((
T
**
)
&
ffn_norm_weights
,
hidden_units_
);
fastertransformer
::
mallocWeights
(
self_attn_weights
.
qkv
,
false
);
fastertransformer
::
mallocWeights
(
self_attn_weights
.
qkv
,
attn_bias_
);
fastertransformer
::
mallocWeights
(
self_attn_weights
.
output
,
false
);
fastertransformer
::
mallocWeights
(
self_attn_weights
.
output
,
attn_bias_
);
fastertransformer
::
mallocWeights
(
ffn_weights
.
gating
,
false
);
fastertransformer
::
mallocWeights
(
ffn_weights
.
gating
,
false
);
fastertransformer
::
mallocWeights
(
ffn_weights
.
intermediate
,
false
);
fastertransformer
::
mallocWeights
(
ffn_weights
.
intermediate
,
false
);
...
...
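
As a quick sanity check on what the new mallocWeights(..., attn_bias_) calls allocate, here is a small sketch (mine, not part of the commit) of the per-rank attention-bias lengths implied by the dims set in the constructor above; the example values hidden_units=4096, tensor_para_size=2 are illustrative.

def attn_bias_sizes(hidden_units: int, tensor_para_size: int):
    # mirrors self_attn_weights.qkv.output_dims and self_attn_weights.output.output_dims
    qkv_bias_len = 3 * hidden_units // tensor_para_size  # fused q/k/v bias, column-split
    out_bias_len = hidden_units                          # attention output bias, full hidden size
    return qkv_bias_len, out_bias_len

print(attn_bias_sizes(4096, 2))  # (6144, 4096)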

src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h
 
 #pragma once
@@ -27,8 +28,12 @@ template<typename T>
 struct LlamaDecoderLayerWeight {
 public:
     LlamaDecoderLayerWeight() = delete;
-    LlamaDecoderLayerWeight(
-        size_t hidden_units, size_t inter_size, WeightType weight_type, size_t tensor_para_size, size_t tensor_para_rank);
+    LlamaDecoderLayerWeight(size_t     hidden_units,
+                            size_t     inter_size,
+                            WeightType weight_type,
+                            bool       attn_bias,
+                            size_t     tensor_para_size,
+                            size_t     tensor_para_rank);
     ~LlamaDecoderLayerWeight();
 
     LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete;
     LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete;
@@ -41,13 +46,14 @@ public:
     LlamaFfnWeight<T> ffn_weights{};
 
 private:
     size_t     hidden_units_;
     size_t     inter_size_;
     WeightType weight_type_;
     size_t     bit_size_;
+    bool       attn_bias_;
     size_t     tensor_para_size_;
     size_t     tensor_para_rank_;
     bool       is_maintain_buffer_ = false;
 
     void mallocWeights();
 };

src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc

@@ -15,8 +15,8 @@
  * limitations under the License.
  */
-// Modified from
-https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
 
 #include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h"
 #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
@@ -237,8 +237,8 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
     fusedQKV_masked_attention_dispatch<T>(
         qkv_buf_,
-        nullptr,  // query_weight.bias,
+        weights->qkv.bias,  // query_weight.bias,
         nullptr,  // relative_attention_bias,
         nullptr,
         nullptr,
         key_cache_ptrs,

src/fastertransformer/models/llama/LlamaWeight.cc

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
 
 #include "src/fastertransformer/models/llama/LlamaWeight.h"
@@ -27,6 +28,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
                             size_t     vocab_size,
                             size_t     num_layer,
                             WeightType weight_type,
+                            bool       attn_bias,
                             size_t     tensor_para_size,
                             size_t     tensor_para_rank,
                             int        prefix_cache_len):
@@ -42,7 +44,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
         decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(
-            hidden_units_, inter_size_, weight_type_, tensor_para_size_, tensor_para_rank_));
+            hidden_units_, inter_size_, weight_type_, attn_bias, tensor_para_size_, tensor_para_rank_));
     }
     mallocWeights();

src/fastertransformer/models/llama/LlamaWeight.h

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h
 
 #pragma once
@@ -32,6 +33,7 @@ struct LlamaWeight {
                 size_t     vocab_size,
                 size_t     num_layer,
                 WeightType weight_type,
+                bool       attn_bias,
                 size_t     tensor_para_size,
                 size_t     tensor_para_rank,
                 int        prefix_cache_len);

src/fastertransformer/models/llama/llama_decoder_kernels.cu

@@ -16,13 +16,13 @@ struct res_norm_ops_t {};
 template<typename T>
 struct res_norm_t {
     res_norm_ops_t<T> f;
-    __device__ uint4 addvec(const uint4& a, const uint4& b, float& accum) const
+    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
     {
         uint4 c;
-        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), accum));
-        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), accum));
-        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), accum));
-        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), accum));
+        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
+        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
+        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
+        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
         return c;
     }
     __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
@@ -47,9 +47,9 @@ struct res_norm_ops_t<half> {
         auto y = __float22half2_rn(x);
         return reinterpret_cast<uint&>(y);
     }
-    __device__ float2 add(const float2& a, const float2& b, float& accum) const
+    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
     {
-        float2 c{a.x + b.x, a.y + b.y};
+        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
         accum += c.x * c.x + c.y * c.y;
         return c;
     }
@@ -69,9 +69,9 @@ struct res_norm_ops_t<float> {
     {
         return reinterpret_cast<const uint&>(x);
     }
-    __device__ float add(const float& a, const float& b, float& accum) const
+    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
-        float c = a + b;
+        float c = a + b + bias;
         accum += c * c;
         return c;
     }
@@ -100,25 +100,32 @@ __device__ T blockReduceSum(const cg::thread_block& block, T value)
 }
 
 template<typename T>
-__global__ void fusedAddResidualNorm(
-    T* __restrict__ r_data, T* __restrict__ x_data, const T* __restrict__ scale, float eps, int batch_size, int n_dims)
+__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
+                                         T* __restrict__ x_data,
+                                         const T* __restrict__ bias,
+                                         const T* __restrict__ scale,
+                                         float eps,
+                                         int   batch_size,
+                                         int   n_dims)
 {
     auto block = cg::this_thread_block();
     auto grid  = cg::this_grid();
     constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
-    const auto b = grid.block_rank();
-    uint4* __restrict__ r_ptr = reinterpret_cast<uint4*>(r_data + b * n_dims);
-    uint4* __restrict__ x_ptr = reinterpret_cast<uint4*>(x_data + b * n_dims);
+    const auto batch_idx = grid.block_rank();
+    uint4* __restrict__ r_ptr = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
+    uint4* __restrict__ x_ptr = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
+    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);
     res_norm_t<T> ops;
     float thread_sum{};
     for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
         auto r = r_ptr[i];
         auto x = x_ptr[i];
-        r = ops.addvec(r, x, thread_sum);
+        uint4 b = b_ptr ? b_ptr[i] : uint4{};
+        r = ops.addvec(r, x, b, thread_sum);
         r_ptr[i] = r;
     }
@@ -136,8 +143,8 @@ __global__ void fusedAddResidualNorm(
 }
 
 template<typename T>
-void invokeFusedAddResidualRMSNorm(
-    T* residual, T* inout, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
+void invokeFusedAddBiasResidualRMSNorm(
+    T* residual, T* inout, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
 {
     constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
     FT_CHECK(n_dims % PACK_DIM == 0);
@@ -146,10 +153,12 @@ void invokeFusedAddResidualRMSNorm(
     int n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
     n_threads     = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size
-    fusedAddResidualNorm<<<batch_size, n_threads, 0, stream>>>(residual, inout, scale, eps, batch_size, n_dims);
+    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(residual, inout, bias, scale, eps, batch_size, n_dims);
 }
 
-template void invokeFusedAddResidualRMSNorm(float*, float*, const float*, float, int, int, cudaStream_t);
-template void invokeFusedAddResidualRMSNorm(half*, half*, const half*, float, int, int, cudaStream_t);
+template void
+invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
+template void
+invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
 
 }  // namespace fastertransformer
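
To make the renamed kernel's math easy to check, here is a rough PyTorch reference (an assumption-laden sketch of my own, not the CUDA code) of what a fused add-bias-residual RMSNorm computes: the residual buffer accumulates inout plus the optional bias, and the RMS-normalized, scaled result is returned alongside it.

import torch

def add_bias_residual_rmsnorm(residual, inout, bias, scale, eps=1e-6):
    # residual <- residual + inout (+ bias); output <- RMSNorm(residual) * scale
    residual = residual + inout + (bias if bias is not None else 0)
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype)
    return residual, normed * scale

# toy check in fp16, matching the float/half instantiations above
r = torch.randn(2, 8, dtype=torch.half)
x = torch.randn(2, 8, dtype=torch.half)
b = torch.zeros(8, dtype=torch.half)
w = torch.ones(8, dtype=torch.half)
new_r, y = add_bias_residual_rmsnorm(r, x, b, w)
print(new_r.shape, y.dtype)  # torch.Size([2, 8]) torch.float16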

src/fastertransformer/models/llama/llama_decoder_kernels.h

@@ -5,7 +5,7 @@
 namespace fastertransformer {
 
 template<typename T>
-void invokeFusedAddResidualRMSNorm(
-    T* residual, T* inout, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream);
+void invokeFusedAddBiasResidualRMSNorm(
+    T* residual, T* inout, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream);
 
 }  // namespace fastertransformer

src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc
 
 #include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
 #include "3rdparty/INIReader.h"
@@ -127,6 +128,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
     cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0);
     prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0);
+    attn_bias_        = reader.GetInteger("llama", "attn_bias", 0);
 
     handleMissingParams();
@@ -284,6 +286,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                           vocab_size_,
                                                           num_layer_,
                                                           weight_type_,
+                                                          attn_bias_,
                                                           tensor_para_size_,
                                                           tensor_para_rank,
                                                           prefix_cache_len_);
@@ -297,14 +300,14 @@ std::string LlamaTritonModel<T>::toString()
     std::stringstream ss;
-    ss << "Model: "
-       << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_
-       << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nmax_batch_size: " << max_batch_size_
-       << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_
-       << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_
-       << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
-       << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
-       << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
-       << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
-       << "\nmodel_dir: " << model_dir_ << std::endl;
+    ss << "Model: "
+       << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_
+       << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_
+       << "\nmax_batch_size: " << max_batch_size_ << "\nmax_context_token_num: " << max_context_token_num_
+       << "\nsession_len: " << session_len_ << "\nstep_length: " << step_length_
+       << "\ncache_max_entry_count: " << cache_max_entry_count_ << "\ncache_chunk_size: " << cache_chunk_size_
+       << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
+       << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
+       << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
+       << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << std::endl;
     return ss.str();
 }
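
The new attn_bias flag travels through the generated config: the export step writes it alongside the other [llama] fields (see the attn_bias=attn_bias argument in deploy.py above), and LlamaTritonModel.cc reads it back with reader.GetInteger("llama", "attn_bias", 0). A hedged Python sketch of that round trip; the file name and surrounding keys here are illustrative only.

import configparser

# write, roughly as the export step might
cfg = configparser.ConfigParser()
cfg['llama'] = {'attn_bias': '1', 'use_context_fmha': '1', 'cache_chunk_size': '0'}
with open('config.ini', 'w') as f:
    cfg.write(f)

# read, mirroring reader.GetInteger("llama", "attn_bias", 0)
reader = configparser.ConfigParser()
reader.read('config.ini')
attn_bias = reader.getint('llama', 'attn_bias', fallback=0)
print(attn_bias)  # 1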

src/fastertransformer/triton_backend/llama/LlamaTritonModel.h

@@ -15,7 +15,8 @@
  * limitations under the License.
  */
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h
 
 #pragma once
@@ -91,6 +92,7 @@ private:
     size_t         tensor_para_size_;
     size_t         pipeline_para_size_;
     ft::WeightType weight_type_;
+    bool           attn_bias_;
     size_t         prefix_cache_len_{};