OpenDAS / ColossalAI · Commit 0e06f621

Authored Jul 12, 2022 by Geng Zhang; committed by Frank Lee, Jul 13, 2022

[NFC] polish colossalai/nn/layer/parallel_sequence/_operation.py code style (#1266)

parent c95e18cd

Showing 1 changed file with 25 additions and 49 deletions

colossalai/nn/layer/parallel_sequence/_operation.py  (view file @ 0e06f621)  +25 -49
...
@@ -19,24 +19,17 @@ class RingQK(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads,
-                sub_seq_length):
+    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
         # save tensor for backward
         ctx.save_for_backward(sub_q, sub_k)
         ctx.sub_seq_length = sub_seq_length

         # create local segment of attention score
-        attention_score = torch.empty(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
-            dtype=sub_q.dtype,
-            device=get_current_device()
-        )
+        attention_score = torch.empty(batch_size * num_attention_heads,
+                                      sub_seq_length,
+                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+                                      dtype=sub_q.dtype,
+                                      device=get_current_device())

         # compute local QK^T
         part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
...
@@ -44,7 +37,7 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         start_idx = local_rank * sub_seq_length
         end_idx = (local_rank + 1) * sub_seq_length
-        attention_score[:, :, start_idx: end_idx] = part_a
+        attention_score[:, :, start_idx:end_idx] = part_a

         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
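The two hunks above make up the forward pass: each rank keeps a (batch_size * num_attention_heads, sub_seq_length, head_size) slice of Q and K, allocates a score buffer that spans the full sequence in its last dimension, and writes its own QK^T block before the ring loop fills in the rest. Below is a minimal single-process sketch of that access pattern; the shapes and the ring-source formula are assumptions for illustration, and a real run exchanges sub_k between ranks with ring_forward instead of indexing a list.

import torch

world_size, batch_heads, sub_seq, head_dim = 4, 2, 8, 16
sub_qs = [torch.randn(batch_heads, sub_seq, head_dim) for _ in range(world_size)]
sub_ks = [torch.randn(batch_heads, sub_seq, head_dim) for _ in range(world_size)]

def ring_qk_sim(rank):
    # local score buffer spans the full sequence, as in RingQK.forward
    score = torch.empty(batch_heads, sub_seq, sub_seq * world_size)
    sub_q = sub_qs[rank]
    # local block first, exactly like the start_idx:end_idx assignment above
    score[:, :, rank * sub_seq:(rank + 1) * sub_seq] = torch.matmul(sub_q, sub_ks[rank].transpose(2, 1))
    # world_size - 1 ring steps; we assume the sub_k arriving at step i
    # originated on rank (rank - i - 1) % world_size
    for i in range(world_size - 1):
        src = (rank - i - 1) % world_size
        score[:, :, src * sub_seq:(src + 1) * sub_seq] = torch.matmul(sub_q, sub_ks[src].transpose(2, 1))
    return score

full_scores = torch.cat([ring_qk_sim(r) for r in range(world_size)], dim=1)
reference = torch.matmul(torch.cat(sub_qs, dim=1), torch.cat(sub_ks, dim=1).transpose(2, 1))
assert torch.allclose(full_scores, reference, atol=1e-5)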
...
@@ -63,19 +56,18 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

         # calculate gradient of sub_k
-        grad_k = torch.matmul(grad_output.transpose(2, 1),
-                              sub_q)
+        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
-        grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
+        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size

         # calculate gradient for sub_q
-        grad_q = torch.zeros_like(sub_q,
-                                  dtype=sub_q.dtype,
-                                  device=get_current_device(),
-                                  )
+        grad_q = torch.zeros_like(
+            sub_q,
+            dtype=sub_q.dtype,
+            device=get_current_device(),
+        )

         # compute with local sub_k
         start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
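The sub_k gradient in this hunk follows from the identity dL/dK = (dL/dS)^T Q for S = Q K^T: each rank's grad_output against its local sub_q is a partial gradient for the full K, so the partials are summed with dist.all_reduce and each rank then keeps only its own row slice. A quick single-process check of the underlying identity, with shapes chosen only for illustration:

import torch

q = torch.randn(2, 8, 16)
k = torch.randn(2, 8, 16, requires_grad=True)
score = torch.matmul(q, k.transpose(2, 1))   # S = Q K^T
grad_output = torch.randn_like(score)
score.backward(grad_output)
# dL/dK = (dL/dS)^T Q, the formula used for grad_k above
manual_grad_k = torch.matmul(grad_output.transpose(2, 1), q)
assert torch.allclose(k.grad, manual_grad_k, atol=1e-5)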
...
@@ -85,7 +77,7 @@ class RingQK(torch.autograd.Function):
         for i in range(local_world_size - 1):
             sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
-            grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
+            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

         grad_q /= local_world_size
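Both backward passes lean on _calc_current_device_range and _calc_incoming_device_range, whose bodies sit outside this diff. A hypothetical reconstruction, inferred only from the call sites above and the assumed ring direction; the actual helpers in the file may differ:

def _calc_current_device_range(local_rank, sub_seq_length):
    # slice of the full sequence owned by this rank (hypothetical body)
    start = local_rank * sub_seq_length
    return start, start + sub_seq_length

def _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length):
    # after i + 1 ring steps, the arriving tensor is assumed to have
    # originated on rank (local_rank - i - 1) % local_world_size
    src = (local_rank - i - 1) % local_world_size
    start = src * sub_seq_length
    return start, start + sub_seq_length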
...
@@ -99,23 +91,16 @@ class RingAV(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads,
-                attention_head_size, sub_seq_length):
+    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)

-        sub_attention_result = torch.zeros(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            attention_head_size,
-            device=get_current_device(),
-            dtype=attention_score.dtype
-        )
+        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
+                                           sub_seq_length,
+                                           attention_head_size,
+                                           device=get_current_device(),
+                                           dtype=attention_score.dtype)

         # save tensors for backward
         ctx.save_for_backward(attention_score, sub_v)
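RingAV.forward mirrors the QK^T pattern with the roles reversed: the full-sequence dimension now sits in the columns of attention_score, and the rank's output rows are accumulated block by block as sub_v travels around the ring, which is why the result buffer starts as torch.zeros rather than torch.empty. A single-process sketch of that accumulation, with assumed shapes:

import torch

world_size, batch_heads, sub_seq, head_size = 4, 2, 8, 16
score = torch.randn(batch_heads, sub_seq, sub_seq * world_size)   # this rank's score rows
sub_vs = [torch.randn(batch_heads, sub_seq, head_size) for _ in range(world_size)]

out = torch.zeros(batch_heads, sub_seq, head_size)
for r, sub_v in enumerate(sub_vs):   # a real run receives each sub_v via ring_forward
    out += torch.matmul(score[:, :, r * sub_seq:(r + 1) * sub_seq], sub_v)

reference = torch.matmul(score, torch.cat(sub_vs, dim=1))
assert torch.allclose(out, reference, atol=1e-4)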
...
@@ -144,23 +129,16 @@ class RingAV(torch.autograd.Function):
         attention_scores, sub_v = ctx.saved_tensors

         # calculate gradient of v
-        grad_v = torch.matmul(
-            attention_scores.transpose(2, 1),
-            grad_output
-        )
+        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
         dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
         grad_v = grad_v[:, local_start_idx:local_end_idx]
         grad_v /= local_world_size

         # calculate gradient for attention score
-        grad_attention_score = torch.zeros_like(
-            attention_scores,
-            dtype=grad_output.dtype,
-            device=get_current_device()
-        )
+        grad_attention_score = torch.zeros_like(attention_scores,
+                                                dtype=grad_output.dtype,
+                                                device=get_current_device())

         # compute with local sub_k
-        grad_attention_score[:, :, local_start_idx: local_end_idx] += torch.matmul(
-            grad_output,
-            sub_v.transpose(2, 1)
-        )
+        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output,
+                                                                                  sub_v.transpose(2, 1))

         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
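As with RingQK, the gradients in this hunk come from the matmul identities for Y = S V: dL/dV = S^T (dL/dY), which is why grad_v is all-reduced and then sliced (every rank's score columns touch the whole of V), and dL/dS = (dL/dY) V^T, filled in block by block. A single-process check of both identities, shapes again assumed:

import torch

s = torch.randn(2, 8, 32, requires_grad=True)   # attention scores
v = torch.randn(2, 32, 16, requires_grad=True)  # values
y = torch.matmul(s, v)
grad_output = torch.randn_like(y)
y.backward(grad_output)
assert torch.allclose(v.grad, torch.matmul(s.transpose(2, 1), grad_output), atol=1e-5)
assert torch.allclose(s.grad, torch.matmul(grad_output, v.transpose(2, 1)), atol=1e-5)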
...
@@ -168,8 +146,6 @@ class RingAV(torch.autograd.Function):
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)

             # compute grad_q
-            grad_attention_score[:, :, start_idx: end_idx] += torch.matmul(
-                grad_output,
-                sub_v.transpose(2, 1)
-            )
+            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output,
+                                                                          sub_v.transpose(2, 1))

         return grad_attention_score, grad_v, None, None, None, None
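For context, a hedged sketch of how the two Functions would be chained inside a sequence-parallel attention layer. The argument lists follow the forward signatures in the diff; the scaling and softmax step is an assumption about the calling code, not something this commit touches:

import math
import torch

# sub_q, sub_k, sub_v: this rank's slices, each of shape
# (batch_size * num_attention_heads, sub_seq_length, attention_head_size)
scores = RingQK.apply(sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length)
scores = scores / math.sqrt(attention_head_size)   # assumed 1/sqrt(d) scaling
probs = torch.softmax(scores, dim=-1)
out = RingAV.apply(probs, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length)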