Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
8a32e428
Commit
8a32e428
authored
Jul 02, 2019
by
Michael Carilli
Browse files
Merging in master
parents
d9c887c2
18f2eaee
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
5 deletions
+14
-5
setup.py
setup.py
+3
-1
tests/L0/run_amp/test_multi_tensor_l2norm.py
tests/L0/run_amp/test_multi_tensor_l2norm.py
+11
-4
No files found.
setup.py
View file @
8a32e428
...
@@ -75,7 +75,9 @@ if "--cuda_ext" in sys.argv:
...
@@ -75,7 +75,9 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_sgd_kernel.cu'
,
'csrc/multi_tensor_sgd_kernel.cu'
,
'csrc/multi_tensor_scale_kernel.cu'
,
'csrc/multi_tensor_scale_kernel.cu'
,
'csrc/multi_tensor_axpby_kernel.cu'
,
'csrc/multi_tensor_axpby_kernel.cu'
,
'csrc/multi_tensor_l2norm_kernel.cu'
],
'csrc/multi_tensor_l2norm_kernel.cu'
,
'csrc/multi_tensor_lamb_stage_1.cu'
,
'csrc/multi_tensor_lamb_stage_2.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
],
'nvcc'
:[
'-lineinfo'
,
'nvcc'
:[
'-lineinfo'
,
'-O3'
,
'-O3'
,
...
...
tests/L0/run_amp/test_multi_tensor_l2norm.py
View file @
8a32e428
...
@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
...
@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
pass
pass
# The tensor creation here is written for convenience, not speed.
# The tensor creation here is written for convenience, not speed.
def
l2norm
(
self
,
sizea
,
sizeb
,
applier
,
repeat_tensors
,
in_type
):
def
l2norm
(
self
,
sizea
,
sizeb
,
applier
,
repeat_tensors
,
in_type
,
per_tensor
):
self
.
overflow_buf
.
zero_
()
self
.
overflow_buf
.
zero_
()
a
=
torch
.
cuda
.
FloatTensor
(
sizea
).
fill_
(
self
.
val
)
a
=
torch
.
cuda
.
FloatTensor
(
sizea
).
fill_
(
self
.
val
)
b
=
torch
.
cuda
.
FloatTensor
(
sizeb
).
fill_
(
self
.
val
)
b
=
torch
.
cuda
.
FloatTensor
(
sizeb
).
fill_
(
self
.
val
)
...
@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase):
...
@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase):
for
i
in
range
(
repeat_tensors
):
for
i
in
range
(
repeat_tensors
):
in_list
+=
[
a
.
clone
().
to
(
in_type
),
b
.
clone
().
to
(
in_type
)]
in_list
+=
[
a
.
clone
().
to
(
in_type
),
b
.
clone
().
to
(
in_type
)]
if
per_tensor
:
norm
=
applier
(
multi_tensor_l2norm
,
self
.
overflow_buf
,
[
in_list
])
norm
,
norm_per_tensor
=
applier
(
multi_tensor_l2norm
,
self
.
overflow_buf
,
[
in_list
],
True
)
normab
=
torch
.
cat
((
a
.
norm
().
view
(
1
),
b
.
norm
().
view
(
1
)))
norm_per_tensor
=
norm_per_tensor
.
view
(
-
1
,
2
)
else
:
norm
,
_
=
applier
(
multi_tensor_l2norm
,
self
.
overflow_buf
,
[
in_list
],
True
)
reference
=
torch
.
cuda
.
FloatTensor
((
sizea
+
sizeb
)
*
repeat_tensors
).
fill_
(
self
.
val
).
norm
()
reference
=
torch
.
cuda
.
FloatTensor
((
sizea
+
sizeb
)
*
repeat_tensors
).
fill_
(
self
.
val
).
norm
()
self
.
assertTrue
(
torch
.
allclose
(
norm
,
reference
))
self
.
assertTrue
(
torch
.
allclose
(
norm
,
reference
))
if
per_tensor
:
self
.
assertTrue
(
torch
.
allclose
(
norm_per_tensor
,
normab
))
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
...
@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase):
...
@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase):
for
applier
in
appliers
:
for
applier
in
appliers
:
for
repeat
in
repeat_tensors
:
for
repeat
in
repeat_tensors
:
for
in_type
in
(
torch
.
float32
,
torch
.
float16
):
for
in_type
in
(
torch
.
float32
,
torch
.
float16
):
self
.
l2norm
(
sizea
,
sizeb
,
applier
,
repeat
,
in_type
,
)
for
per_tensor
in
(
False
,
True
):
self
.
l2norm
(
sizea
,
sizeb
,
applier
,
repeat
,
in_type
,
per_tensor
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment