Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
69142dbb
Commit
69142dbb
authored
Jan 31, 2020
by
Jeff Rasley
Browse files
add csr tensor
parent
e63b6b01
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
0 deletions
+60
-0
deepspeed/pt/deepspeed_csr_tensor.py
deepspeed/pt/deepspeed_csr_tensor.py
+60
-0
No files found.
deepspeed/pt/deepspeed_csr_tensor.py
0 → 100644
View file @
69142dbb
"""
Copyright 2020 The Microsoft DeepSpeed Team
Implementation of a compressed sparse row (CSR) tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""
import
torch
class
CSRTensor
(
object
):
""" Compressed Sparse Row (CSR) Tensor """
def
__init__
(
self
,
dense_tensor
=
None
):
self
.
orig_dense_tensor
=
dense_tensor
if
dense_tensor
is
not
None
:
result
=
torch
.
sum
(
dense_tensor
,
dim
=
1
)
self
.
indices
=
result
.
nonzero
().
flatten
()
self
.
values
=
dense_tensor
[
self
.
indices
]
self
.
dense_size
=
list
(
dense_tensor
.
size
())
else
:
self
.
indices
=
None
self
.
values
=
None
self
.
dense_size
=
None
@
staticmethod
def
type
():
return
"deepspeed.CSRTensor"
def
to_dense
(
self
):
it
=
self
.
indices
.
unsqueeze
(
1
)
full_indices
=
torch
.
cat
([
it
for
_
in
range
(
self
.
dense_size
[
1
])],
dim
=
1
)
return
self
.
values
.
new_zeros
(
self
.
dense_size
).
scatter_add_
(
0
,
full_indices
,
self
.
values
)
def
sparse_size
(
self
):
index_size
=
list
(
self
.
indices
.
size
())
index_size
=
index_size
[
0
]
value_size
=
list
(
self
.
values
.
size
())
value_size
=
value_size
[
0
]
*
value_size
[
1
]
dense_size
=
self
.
dense_size
[
0
]
*
self
.
dense_size
[
1
]
return
index_size
+
value_size
,
dense_size
def
add
(
self
,
b
):
assert
self
.
dense_size
==
b
.
dense_size
self
.
indices
=
torch
.
cat
([
self
.
indices
,
b
.
indices
])
self
.
values
=
torch
.
cat
([
self
.
values
,
b
.
values
])
def
__str__
(
self
):
sparse_size
,
dense_size
=
self
.
sparse_size
()
return
"DeepSpeed.CSRTensor(indices_size={}, values_size={}, "
\
"dense_size={}, device={}, reduction_factor={})"
.
format
(
self
.
indices
.
size
(),
self
.
values
.
size
(),
self
.
dense_size
,
self
.
indices
.
get_device
(),
dense_size
/
sparse_size
)
def
__repr__
(
self
):
return
self
.
__str__
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment