OpenDAS / ColossalAI · Commits

Unverified commit 7bd0bee8, authored May 04, 2023 by Hongxin Liu, committed by GitHub on May 04, 2023
Parent: 1a60dc07

[chat] add opt attn kernel (#3655)

* [chat] add opt attn kernel
* [chat] disable xformer during fwd
Showing 4 changed files with 117 additions and 0 deletions (+117 / -0):

* applications/Chat/benchmarks/benchmark_opt_lora_dummy.py (+6 / -0)
* applications/Chat/coati/kernels/__init__.py (+6 / -0)
* applications/Chat/coati/kernels/opt_attn.py (+87 / -0)
* applications/Chat/coati/kernels/wrapper.py (+18 / -0)
applications/Chat/benchmarks/benchmark_opt_lora_dummy.py (view file @ 7bd0bee8)
@@ -101,6 +101,11 @@ def main(args):
    initial_model = deepcopy(actor).cuda().half()
    reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half()

    if args.use_kernels:
        from coati.kernels import convert_to_xformer_model
        actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
                                                         (actor, critic, initial_model, reward_model))

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)

@@ -184,5 +189,6 @@ if __name__ == '__main__':
    parser.add_argument('--lora_rank', type=int, default=0)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    parser.add_argument('--offload_inference_models', action='store_true', default=False)
    parser.add_argument('--use_kernels', action='store_true', default=False)
    args = parser.parse_args()
    main(args)
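With --use_kernels set, the benchmark converts the actor, critic, initial model, and reward model to the xformers-backed attention before measuring throughput. An illustrative invocation, assuming every argument not shown in this diff is left at its default (the launcher and the remaining flags are not part of this change):

python benchmark_opt_lora_dummy.py --use_kernels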
applications/Chat/coati/kernels/__init__.py (new file, 0 → 100644, view file @ 7bd0bee8)
from .wrapper import convert_to_xformer_model, recover_from_xformer_model

__all__ = [
    'convert_to_xformer_model',
    'recover_from_xformer_model',
]
applications/Chat/coati/kernels/opt_attn.py (new file, 0 → 100644, view file @ 7bd0bee8)
from typing import Optional, Tuple

import torch
import xformers.ops as xops
from torch import Tensor
from transformers.models.opt.modeling_opt import OPTAttention


# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
class XOPTAttention(OPTAttention):

    # def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
    #     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()

    def forward(
        self,
        hidden_states: Tensor,
        key_value_states: Optional[Tensor] = None,
        past_key_value: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        if not self.training:
            return super().forward(hidden_states, key_value_states, past_key_value, attention_mask,
                                   layer_head_mask, output_attentions)
        """Input shape: Batch x Time x Channel"""
        assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
        assert not output_attentions, 'Xformers attention does not support output_attentions'

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = xops.memory_efficient_attention(query_states,
                                                      key_states,
                                                      value_states,
                                                      attn_bias=xops.LowerTriangularMask(),
                                                      p=self.dropout if self.training else 0.0,
                                                      scale=self.scaling)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        attn_weights_reshaped = None

        return attn_output, attn_weights_reshaped, past_key_value
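The only functional change from the upstream OPTAttention is the kernel call itself: _shape returns tensors laid out as (batch, num_heads, seq_len, head_dim), while xformers' memory_efficient_attention expects (batch, seq_len, num_heads, head_dim), which is why the code transposes dimensions 1 and 2 just before the call. A minimal, self-contained sketch of that call, assuming xformers is installed and a CUDA device is available (all sizes below are illustrative, not taken from this commit):

import torch
import xformers.ops as xops

# Illustrative sizes: batch 2, sequence length 16, 4 heads of dim 64.
bsz, seq_len, num_heads, head_dim = 2, 16, 4, 64

# xformers layout: (batch, seq_len, num_heads, head_dim), half precision on GPU.
q = torch.randn(bsz, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),  # causal mask, as in the decoder-only OPT attention above
    p=0.0,                                 # dropout probability; XOPTAttention passes self.dropout during training
    scale=head_dim ** -0.5,                # plays the role of OPTAttention's self.scaling
)
print(out.shape)  # torch.Size([2, 16, 4, 64]); forward() then reshapes to (bsz, tgt_len, embed_dim)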
applications/Chat/coati/kernels/wrapper.py (new file, 0 → 100644, view file @ 7bd0bee8)
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTAttention

from .opt_attn import XOPTAttention


def convert_to_xformer_model(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if isinstance(module, OPTAttention):
            module.__class__ = XOPTAttention
    return model


def recover_from_xformer_model(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if isinstance(module, XOPTAttention):
            module.__class__ = OPTAttention
    return model
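convert_to_xformer_model patches modules in place by reassigning __class__, so no weights are copied and the change is fully reversible with recover_from_xformer_model. A minimal usage sketch, assuming a transformers version whose OPT implementation still uses the OPTAttention class (the checkpoint name is only an example):

import torch
from transformers import OPTForCausalLM

from coati.kernels import convert_to_xformer_model, recover_from_xformer_model

# Any OPT checkpoint with OPTAttention modules works; opt-125m is just a small example.
model = OPTForCausalLM.from_pretrained('facebook/opt-125m').cuda().half()

# Swap every OPTAttention for the xformers-backed XOPTAttention before training.
model = convert_to_xformer_model(model)
model.train()
# ... training loop: XOPTAttention.forward now calls xops.memory_efficient_attention ...

# In eval mode XOPTAttention already falls back to the stock implementation, but the
# original classes can also be restored explicitly, e.g. before saving or exporting.
model = recover_from_xformer_model(model)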