OpenDAS / ColossalAI · Commit 442a2975

[NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962)

Authored May 15, 2022 by MaxT
Committed May 17, 2022 by binmakeswell
Parent: 89e2767a

Showing 1 changed file with 19 additions and 12 deletions (+19 / -12)

colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -19,21 +19,25 @@
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
                     int hidden_size, int num_heads, float attn_dropout_ratio,
                     float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);

  virtual ~MultiHeadAttention();

  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

  void Backward(const T *grad_output_ptr, const T *input_ptr,
                const T *output_ptr, const T *input_mask_ptr,
                T *grad_input_ptr);

  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
                     T *output_ptr, T *buffer);

  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
                     const T *output_ptr, const T *grad_output_ptr,
                     T *grad_input_attn_layer_bwptr, T *buffer);

  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;
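For context, the interface declared in this hunk can be exercised roughly as follows. This is a minimal usage sketch, not code from the commit: it assumes the header is included from the ColossalAI csrc tree, that the dimensions and the meaning of pre_or_postLayerNorm follow from their names, and that the input, mask, and output buffers are device pointers allocated and filled elsewhere (CUDA setup and error handling are omitted).

  #include "multihead_attention_1d.h"

  void run_forward_sketch(const float *d_input, const float *d_mask,
                          float *d_output) {
    // Hypothetical sizes chosen only for illustration; real values come
    // from the model configuration.
    int layer_id = 0, max_batch_tokens = 4096, max_seq_len = 512;
    int hidden_size = 1024, num_heads = 16;
    float attn_dropout = 0.1f, hidden_dropout = 0.1f;
    bool pre_layernorm = true;  // flag semantics assumed from its name

    MultiHeadAttention<float> attn(layer_id, max_batch_tokens, max_seq_len,
                                   hidden_size, num_heads, attn_dropout,
                                   hidden_dropout, pre_layernorm);

    // Declare the shape of the current batch before running the layer.
    attn.set_cur_batch_shape(/*batch_size=*/8, /*seq_len=*/512);

    // Forward pass over device buffers (assumed allocated elsewhere).
    attn.Forward(d_input, d_mask, d_output);
  }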
@@ -83,14 +87,17 @@ class MultiHeadAttention {
  }

  _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
  _soft_out_ptr =
      cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
  _ctx_bufB_ptr =
      cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
  _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

  // buffer size needed by attn bw
  size_t smem_size =
      4 * _max_batch_tokens * _hidden_size / pg_size +
      std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
               _max_batch_tokens * _heads / pg_size * _max_seq_len);

  if (!_shared_mem_ptr) {
    cuda_free(_shared_mem_ptr);
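The smem_size expression in this hunk sizes a single shared workspace, in elements of T: four buffers of _max_batch_tokens * _hidden_size / pg_size elements, plus whichever is larger of three more such buffers or one attention-score buffer of _max_batch_tokens * _heads / pg_size * _max_seq_len elements. The standalone sketch below reproduces the same arithmetic with hypothetical dimensions chosen only for illustration; it is not part of the commit.

  #include <algorithm>
  #include <cstdio>

  int main() {
    // Hypothetical dimensions; the real values come from the layer config.
    long max_batch_tokens = 4096;  // max batch_size * seq_len
    long hidden_size = 1024;
    long heads = 16;
    long max_seq_len = 512;
    long pg_size = 2;  // tensor-parallel process-group size (assumed)

    // Same formula as in the header: four hidden-size buffers plus the
    // larger of three more hidden-size buffers or the score buffer.
    long smem_size =
        4 * max_batch_tokens * hidden_size / pg_size +
        std::max(3 * max_batch_tokens * hidden_size / pg_size,
                 max_batch_tokens * heads / pg_size * max_seq_len);

    std::printf("shared workspace elements: %ld\n", smem_size);
    return 0;
  }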