OpenDAS / ColossalAI

Commit 22717a85 (Unverified)
Authored Jun 22, 2022 by ver217; committed by GitHub, Jun 22, 2022

[tensor] add embedding bag op (#1156)

Parent: ae861519

Showing 3 changed files with 179 additions and 0 deletions (+179 −0)
colossalai/nn/_ops/__init__.py              +1   −0
colossalai/nn/_ops/embedding_bag.py         +122 −0
tests/test_tensor/test_embedding_bag_tp.py  +56  −0
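Before the diff itself, a quick refresher on the op being wrapped may help. `torch.nn.functional.embedding_bag` fuses an embedding lookup with a per-bag reduction: `offsets` marks where each bag starts in the flat `inputs`, and each bag of looked-up rows is reduced (mean by default) to a single vector. A minimal, self-contained sketch (the values are arbitrary, chosen to mirror the test below):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)                       # lookup table: 10 embeddings of dim 4
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])   # flat indices
offsets = torch.tensor([0, 4])                    # two bags: inputs[0:4] and inputs[4:8]

# Each bag of looked-up rows is reduced (mean by default) to one vector.
out = F.embedding_bag(inputs, weight, offsets=offsets)
print(out.shape)    # torch.Size([2, 4])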
colossalai/nn/_ops/__init__.py
@@ -4,3 +4,4 @@ from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
+from .embedding_bag import colo_embedding_bag
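This one-line re-export is what activates the new op: importing `colossalai.nn._ops` executes the `@colo_op_impl(F.embedding_bag)` decorator in `embedding_bag.py`, so calls to `F.embedding_bag` on ColoTensors are intercepted via `__torch_function__`. A minimal sketch of that generic registration mechanism, assuming nothing about ColossalAI internals (`MyTensor`, `my_op_impl`, and `_OP_TABLE` are hypothetical, not ColossalAI's implementation):

import torch
import torch.nn.functional as F

_OP_TABLE = {}    # hypothetical registry: torch function -> custom implementation

def my_op_impl(torch_func):
    """Register a handler for `torch_func` (in the spirit of colo_op_impl)."""
    def wrapper(impl):
        _OP_TABLE[torch_func] = impl
        return impl
    return wrapper

class MyTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _OP_TABLE:
            return _OP_TABLE[func](*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

@my_op_impl(F.embedding_bag)
def my_embedding_bag(inputs, weight, offsets=None, **kwargs):
    # Unwrap to plain tensors so this call does not re-dispatch to us.
    unwrap = lambda t: t.as_subclass(torch.Tensor) if isinstance(t, MyTensor) else t
    return F.embedding_bag(unwrap(inputs), unwrap(weight),
                           offsets=unwrap(offsets) if offsets is not None else None,
                           **kwargs)

# Any call with a MyTensor argument now routes through my_embedding_bag.
weight = torch.randn(10, 4).as_subclass(MyTensor)
out = F.embedding_bag(torch.tensor([1, 2, 4, 5]), weight, offsets=torch.tensor([0, 2]))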
colossalai/nn/_ops/embedding_bag.py (new file, mode 100644)
import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ParallelAction, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor


def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                             weight: ColoTensor,
                             offsets: Optional[Tensor] = None,
                             max_norm: Optional[float] = None,
                             norm_type: float = 2,
                             scale_grad_by_freq: bool = False,
                             mode: str = "mean",
                             sparse: bool = False,
                             per_sample_weights: Optional[Tensor] = None,
                             include_last_offset: bool = False,
                             padding_idx: Optional[int] = None) -> ColoTensor:
    # embedding_bag_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim / P) shards.
    # Replicate the input so every rank can look up its shard of the table.
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

    output_parallel = F.embedding_bag(input_tensor,
                                      weight,
                                      offsets=offsets,
                                      max_norm=max_norm,
                                      norm_type=norm_type,
                                      scale_grad_by_freq=scale_grad_by_freq,
                                      mode=mode,
                                      sparse=sparse,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = TensorSpec(
        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
        ParallelAction(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.spec.parallel_action.gather_out:
        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
    return output


def colo_embedding_bag_1d(tp_mode: str,
                          input_tensor: ColoTensor,
                          weight: ColoTensor,
                          offsets: Optional[Tensor] = None,
                          max_norm: Optional[float] = None,
                          norm_type: float = 2,
                          scale_grad_by_freq: bool = False,
                          mode: str = "mean",
                          sparse: bool = False,
                          per_sample_weights: Optional[Tensor] = None,
                          include_last_offset: bool = False,
                          padding_idx: Optional[int] = None) -> ColoTensor:
    assert tp_mode in ('col',)
    funcs = {'col': colo_embedding_bag_1Dcol}
    return funcs[tp_mode](input_tensor,
                          weight,
                          offsets=offsets,
                          max_norm=max_norm,
                          norm_type=norm_type,
                          scale_grad_by_freq=scale_grad_by_freq,
                          mode=mode,
                          sparse=sparse,
                          per_sample_weights=per_sample_weights,
                          include_last_offset=include_last_offset,
                          padding_idx=padding_idx)


@colo_op_impl(F.embedding_bag)
def colo_embedding_bag(input_tensor: GeneralTensor,
                       weight: GeneralTensor,
                       offsets: Optional[Tensor] = None,
                       max_norm: Optional[float] = None,
                       norm_type: float = 2,
                       scale_grad_by_freq: bool = False,
                       mode: str = "mean",
                       sparse: bool = False,
                       per_sample_weights: Optional[Tensor] = None,
                       include_last_offset: bool = False,
                       padding_idx: Optional[int] = None):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.

    This method looks up an embedding table.
    """
    input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))

    # Handle different parallel actions.
    if not weight.has_spec():    # No model parallelism applied
        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(
            F.embedding_bag(input_tensor,
                            weight,
                            offsets=offsets,
                            max_norm=max_norm,
                            norm_type=norm_type,
                            scale_grad_by_freq=scale_grad_by_freq,
                            mode=mode,
                            sparse=sparse,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=include_last_offset,
                            padding_idx=padding_idx))
    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single model parallelism applied
        if weight.spec.is_1D_col():
            tp_mode = 'col'
        else:
            raise NotImplementedError
        return colo_embedding_bag_1d(tp_mode,
                                     input_tensor,
                                     weight,
                                     offsets=offsets,
                                     max_norm=max_norm,
                                     norm_type=norm_type,
                                     scale_grad_by_freq=scale_grad_by_freq,
                                     mode=mode,
                                     sparse=sparse,
                                     per_sample_weights=per_sample_weights,
                                     include_last_offset=include_last_offset,
                                     padding_idx=padding_idx)
    else:
        raise NotImplementedError
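To see why the sharded path is correct: with the table split column-wise, each rank's `F.embedding_bag` over its (num_embeddings, embedding_dim / P) shard yields the matching columns of the output, and the per-bag reduction acts on each column independently, so gathering (concatenating) the partials along the last dimension reproduces the unsharded result. A single-process sketch that simulates P = 2 ranks with plain tensors (no ColossalAI types; `P` and the chunking are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
P = 2                                     # simulated tensor-parallel world size
weight = torch.randn(10, 4)               # full (num_embeddings, embedding_dim) table
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])

full = F.embedding_bag(inputs, weight, offsets=offsets)

# Column-shard the table: each simulated rank keeps embedding_dim / P columns.
shards = [w.contiguous() for w in weight.chunk(P, dim=-1)]
partials = [F.embedding_bag(inputs, w, offsets=offsets) for w in shards]

# Gathering the partial outputs along the last dim matches the unsharded op.
gathered = torch.cat(partials, dim=-1)
assert torch.allclose(full, gathered)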
tests/test_tensor/test_embedding_bag_tp.py (new file, mode 100644)
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, distspec, ColoParameter
from torch.nn import functional as F
from functools import partial
import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
from _utils import tensor_equal, tensor_shard_equal


def init_1d_col(weight):
    # Shard the weight along its last dim (the embedding dim) across the 1D parallel group.
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_spec(spec)


def run_with_spec(spec_init_func):
    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone())
    spec_init_func(weight)
    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()
    out = model(inputs, offsets=offsets)
    colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_col)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_embedding_bag_1d(4)
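`tensor_equal` and `tensor_shard_equal` come from `tests/test_tensor/_utils`, which is not part of this diff. For intuition only, here is a plausible sketch of the shard comparison the test relies on, matching the column sharding used above; `shard_equal_sketch` is hypothetical, not the real helper:

import torch

def shard_equal_sketch(full: torch.Tensor, shard: torch.Tensor,
                       rank: int, world_size: int) -> bool:
    """Compare `full` against the column shard that `rank` is expected to hold."""
    if full.shape == shard.shape:        # world_size == 1: nothing was sharded
        return torch.allclose(full, shard)
    return torch.allclose(full.chunk(world_size, dim=-1)[rank], shard)

# e.g. rank 1 of 2 should hold columns 2:4 of a (10, 4) gradient
g = torch.randn(10, 4)
assert shard_equal_sketch(g, g[:, 2:4], rank=1, world_size=2)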