flash-attention / Commits / 63670fd8

Commit 63670fd8, authored Dec 27, 2022 by Tri Dao
Parent: 9d797d88

Implement generation for GPT

Showing 5 changed files, with 242 additions and 39 deletions (+242, -39):

  flash_attn/models/gpt.py             +13  -5
  flash_attn/modules/mha.py            +82  -24
  flash_attn/utils/generation.py       +64  -0
  tests/models/test_gpt.py             +0   -10
  tests/models/test_gpt_generation.py  +83  -0
flash_attn/models/gpt.py  (+13, -5)

@@ -20,6 +20,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_sequence_parallel_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import GenerationMixin

 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                          if process_group is None else {})
     parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
-                        softmax_scale=softmax_scale, causal=True,
+                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         use_flash_attn=use_flash_attn,
                         **serial_kwargs, **parallel_kwargs, **factory_kwargs)
@@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)

-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
@@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel):
                 residual_in_fp32=True
             )
         mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        if inference_params is not None:
+            mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
             hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
         return hidden_states


-class GPTLMHeadModel(GPTPreTrainedModel):
+class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):

     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

-    def forward(self, input_ids, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids=position_ids)
+    def forward(self, input_ids, position_ids=None, inference_params=None):
+        """
+        inference_params: for generation. Adapted from Megatron-LM (and Apex)
+        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        """
+        hidden_states = self.transformer(input_ids, position_ids=position_ids,
+                                         inference_params=inference_params)
         lm_logits = self.lm_head(hidden_states)
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
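The thread running through this change is the new inference_params argument: GPTLMHeadModel.forward passes it to GPTModel, which hands it to every MHA layer via mixer_kwargs, so each layer can read and update its slice of the KV cache. A minimal sketch of the calling pattern, distilled from greedy_decode in flash_attn/utils/generation.py below; model and input_ids are assumed to already exist, as in the new test:

    import torch
    from flash_attn.utils.generation import InferenceParams

    # Prompt pass: sequence_len_offset == 0, so the whole prompt is run with the usual
    # causal attention and each layer fills its slot in key_value_memory_dict.
    inference_params = InferenceParams(max_sequence_len=30, max_batch_size=input_ids.shape[0])
    logits = model(input_ids, inference_params=inference_params).logits[:, -1]

    # Decode pass: advance the offset and feed only the newly chosen token; each layer
    # appends one key/value pair to its cache instead of re-running the whole prefix.
    inference_params.sequence_len_offset = input_ids.shape[1]
    next_token = logits.argmax(dim=-1, keepdim=True)               # (batch, 1)
    position_ids = torch.full_like(next_token, inference_params.sequence_len_offset)
    logits = model(next_token, position_ids=position_ids,
                   inference_params=inference_params).logits[:, -1]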
flash_attn/modules/mha.py  (+82, -24)

@@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
+    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module):
                 If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                 If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                 (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into qkv.
             max_seqlen: int. Maximum sequence length in the batch.
@@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
                 qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
             if self.triton and (self.dropout_p == 0 or not self.training):
-                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+                output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                 max_seqlen = seqlen
@@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module):
                                            device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
                     qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
@@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
+    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None,
+                max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
             max_seqlen: int. Maximum sequence length in the batch of q.
@@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module):
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module):
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                 self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
             if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
-                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+                output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
                 kv = rearrange(kv, 'b s ... -> (b s) ...')
@@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module):
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                     self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
@@ -187,15 +192,17 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, qkv, key_padding_mask=None):
+    def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        causal = self.causal if causal is None else causal
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
@@ -205,7 +212,7 @@ class SelfAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
@@ -233,16 +240,18 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, q, kv, key_padding_mask=None):
+    def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
         k, v = kv.unbind(dim=2)
@@ -254,7 +263,7 @@ class CrossAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
@@ -280,7 +289,7 @@ class MHA(nn.Module):
     """

     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
                  rotary_emb_scale_base=0,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
@@ -294,6 +303,7 @@ class MHA(nn.Module):
         self.embed_dim = embed_dim
         self.cross_attn = cross_attn
         self.causal = causal
+        self.layer_idx = layer_idx
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
@@ -315,6 +325,8 @@ class MHA(nn.Module):
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (LinearResidual if not fused_bias_fc
                             else partial(FusedDense, return_residual=True))
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
@@ -323,7 +335,6 @@ class MHA(nn.Module):
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
-            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
@@ -335,14 +346,41 @@ class MHA(nn.Module):
                                            groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                            groups=2 * embed_dim)
-            inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
+        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                                     attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, 1, nheads, head_dim)
+        """
+        assert not self.dwconv, 'Generation does not support dwconv yet'
+        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        # Pre-allocate memory for key-values for inference.
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            inference_kv_cache = torch.empty(
+                inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+                self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+        else:
+            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+        # Adjust key and value for inference
+        batch_start = inference_params.batch_size_offset
+        batch_end = batch_start + kv.shape[0]
+        assert batch_end <= inference_kv_cache.shape[0]
+        sequence_start = inference_params.sequence_len_offset
+        sequence_end = sequence_start + kv.shape[1]
+        assert sequence_end <= inference_kv_cache.shape[1]
+        # Copy key and values.
+        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        return kv
+
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                **kwargs):
+                inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
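_update_kv_cache is the heart of the change: each layer gets one pre-allocated buffer of shape (max_batch_size, max_sequence_len, 2, nheads, head_dim), writes the current step's key/value at sequence_len_offset, and returns a view of everything cached so far. A self-contained sketch of just that indexing, with made-up sizes and no model (batch_size_offset is ignored here for brevity):

    import torch

    max_batch_size, max_seqlen, nheads, head_dim = 2, 8, 4, 16
    # One pre-allocated buffer per layer, keyed by layer_idx in key_value_memory_dict.
    cache = torch.empty(max_batch_size, max_seqlen, 2, nheads, head_dim)

    def update(kv, seqlen_offset):
        # kv: (batch, new_tokens, 2, nheads, head_dim) -- a single token per step when decoding.
        batch, new_tokens = kv.shape[0], kv.shape[1]
        cache[:batch, seqlen_offset:seqlen_offset + new_tokens] = kv
        # Return everything cached so far, so attention sees the full history.
        return cache[:batch, :seqlen_offset + new_tokens]

    prompt_kv = torch.randn(2, 5, 2, nheads, head_dim)   # prompt of length 5
    step_kv = torch.randn(2, 1, 2, nheads, head_dim)     # one decoded token
    assert update(prompt_kv, 0).shape[1] == 5
    assert update(step_kv, 5).shape[1] == 6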
@@ -355,6 +393,8 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            inference_params: for generation. Adapted from Megatron-LM (and Apex)
+            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
         if cu_seqlens is not None:
             assert max_seqlen is not None
@@ -366,6 +406,10 @@ class MHA(nn.Module):
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
+        if inference_params is not None:
+            assert key_padding_mask is None
+            assert cu_seqlens is None and max_seqlen is None
+            assert not self.dwconv

         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
@@ -378,12 +422,22 @@ class MHA(nn.Module):
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            if not self.checkpointing:
-                context = self.inner_attn(qkv, **kwargs)
+            if inference_params is None:
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv)
+                if not self.checkpointing:
+                    context = self.inner_attn(qkv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                q = qkv[:, :, 0]
+                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                # If we're processing the prompt, causal=None (use self.causal).
+                # If we're decoding, then causal=False.
+                causal = None if inference_params.sequence_len_offset == 0 else False
+                context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
                 q = self.Wq(x)
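During decoding only one new query attends to the cached keys, and the last row of a causal attention matrix already covers every previous position, so the causal mask adds nothing; that is why the inference path passes causal=False once the prompt is done, and causal=None (fall back to self.causal) for the prompt itself. A small self-contained check of that equivalence, with plain softmax attention and illustrative shapes only:

    import torch

    torch.manual_seed(0)
    seqlen, d = 6, 8
    q_all = torch.randn(seqlen, d)
    k = torch.randn(seqlen, d)
    v = torch.randn(seqlen, d)

    # Full causal attention over all positions.
    scores = q_all @ k.T / d ** 0.5
    mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), 1)
    causal_out = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1) @ v

    # Unmasked attention for just the last query against all keys (the decode step).
    last_scores = q_all[-1:] @ k.T / d ** 0.5
    decode_out = torch.softmax(last_scores, dim=-1) @ v

    assert torch.allclose(causal_out[-1:], decode_out, atol=1e-6)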
@@ -401,10 +455,14 @@ class MHA(nn.Module):
                                 'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            if not self.checkpointing:
-                context = self.inner_attn(q, kv, **kwargs)
+            if inference_params is None:
+                if not self.checkpointing:
+                    context = self.inner_attn(q, kv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                kv = self._update_kv_cache(kv, inference_params)
+                context = self.inner_cross_attn(q, kv, causal=False)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
flash_attn/utils/generation.py  (new file, +64)

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field

import torch
from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)


def greedy_decode(input_ids, model, max_length):
    """Greedy decoding. This is a very simple implementation.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    scores = []
    with torch.inference_mode():
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = logits.argmax(dim=-1)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = logits.argmax(dim=-1)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    return GreedySearchDecoderOnlyOutput(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )


class GenerationMixin:

    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
        output = greedy_decode(input_ids, self, max_length)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
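Because GPTLMHeadModel now mixes in GenerationMixin, greedy generation becomes a single call that mirrors the Hugging Face generate API. A short usage sketch, based on test_gpt_generation.py added below; the prompt text and max_length are just examples:

    import torch
    from transformers import GPT2Config, GPT2Tokenizer
    from flash_attn.models.gpt import GPTLMHeadModel

    config = GPT2Config.from_pretrained("gpt2")
    model = GPTLMHeadModel.from_pretrained("gpt2", config).cuda().to(dtype=torch.float16)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()

    # return_dict_in_generate=True returns .sequences and .scores, as in HF's generate.
    out = model.generate(input_ids=input_ids, max_length=30,
                         return_dict_in_generate=True, output_scores=True)
    print(tokenizer.batch_decode(out.sequences))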
tests/models/test_gpt.py  (+0, -10)

@@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-def get_hf_models(model_name, config, dtype):
-    pretrained_state_dict = state_dict_from_pretrained(model_name)
-    model_hf = GPT2LMHeadModelHF(config)
-    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
-    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
-    model_hf.load_state_dict(pretrained_state_dict, strict=False)
-    model_hf.cuda().to(dtype=dtype)
-    return model_hf
-
-
 @pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_non_optimized(model_name):
tests/models/test_gpt_generation.py  (new file, +83)

import re

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import greedy_decode


# TODO: test with rotary embedding

@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode(model_name, optimized):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_dense_gelu_dense = True
        config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
    max_length = 30

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         return_dict_in_generate=True, output_scores=True)
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                 return_dict_in_generate=True, output_scores=True)

    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()