OpenDAS / apex
Commit f3528d99, authored May 03, 2019 by Michael Carilli
Parent: d0505433

Converting dispatch macros in fused_adam_cuda_kernel.cu
Showing 1 changed file with 15 additions and 15 deletions.
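Context for the diff below: this commit continues the move away from PyTorch's lambda-based AT_DISPATCH_* macros toward apex's own switch-based DISPATCH_* macros. Those macros are defined elsewhere (apex's csrc/type_shim.h) and are not part of this diff, so the sketch below is a reconstruction of the shape the new call sites assume; the exact definition in type_shim.h may differ. Two points matter for reading the diff: the dispatch body is passed as plain variadic arguments rather than as a lambda, and the extra LEVEL argument is token-pasted into the scalar typedef, which is why the converted call sites pass 0 and their bodies name the type scalar_t_0.

    // Sketch of the switch-based dispatch shape the new call sites assume
    // (compare apex csrc/type_shim.h; this reconstruction may differ in detail).
    #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                 \
        switch(TYPE)                                                        \
        {                                                                   \
            case at::ScalarType::Float:                                     \
            {                                                               \
                using scalar_t_##LEVEL = float;                             \
                __VA_ARGS__;   /* the dispatch body, expanded per type */   \
                break;                                                      \
            }                                                               \
            case at::ScalarType::Half:                                      \
            {                                                               \
                using scalar_t_##LEVEL = at::Half;                          \
                __VA_ARGS__;                                                \
                break;                                                      \
            }                                                               \
            default:                                                        \
                AT_ERROR(#NAME, " not implemented for '",                   \
                         toString(TYPE), "'");                              \
        }

Because the body is no longer a lambda, each converted dispatch closes with ); instead of the old })); which accounts for a good share of the changed lines below.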
csrc/fused_adam_cuda_kernel.cu (+15, -15), view file @ f3528d99
@@ -203,7 +203,7 @@ void fused_adam_cuda(
                 tsize,
                 (adamMode_t) mode,
                 decay);
-        ));
+        );
     } else {
         using namespace at;
         DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
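A rough expansion of the converted call site above, assuming the macro sketch given earlier (the dispatch body itself is elided by the diff, so it is left as a placeholder):

    // What DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", <body>);
    // roughly expands to, under the sketch above:
    switch (g.scalar_type()) {
        case at::ScalarType::Double: {
            using scalar_t_0 = double;
            /* <body>: define accscalar_t, launch adam_cuda_kernel<...> */
            break;
        }
        case at::ScalarType::Float: {
            using scalar_t_0 = float;
            /* same <body>, instantiated for float */
            break;
        }
        default:
            AT_ERROR("adam_cuda_kernel", " not implemented for '",
                     toString(g.scalar_type()), "'");
    }

The unqualified toString in the error path is also why the context line "using namespace at;" sits just before each dispatch.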
@@ -261,14 +261,14 @@ void fused_adam_cuda_mt(
         AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
 //dich is done on the gradient type
         if (tl_sz == 5) {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
-                using accscalar_t = at::acc_type<scalar_t, true>;
+            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
                 multi_tensor_apply<5>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<5, accscalar_t, scalar_t>(),
+                    AdamFunctor<5, accscalar_t, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
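The LEVEL argument (the 0) and the _0 suffix look redundant for a single dispatch, but they keep the typedef unambiguous when one dispatch nests inside another, which a lambda-free macro otherwise could not distinguish. A purely hypothetical nesting, with made-up names (grads, params, apply_pairwise are illustrative, not part of apex):

    // Hypothetical nested use: dispatch on two independent tensor types.
    DISPATCH_FLOAT_AND_HALF(grads[0].scalar_type(), 0, "outer_dispatch",
        DISPATCH_FLOAT_AND_HALF(params[0].scalar_type(), 1, "inner_dispatch",
            // inner body sees both typedefs without shadowing
            apply_pairwise<scalar_t_0, scalar_t_1>(grads, params);
        );
    );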
@@ -276,16 +276,16 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         } else {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
-                using accscalar_t = at::acc_type<scalar_t, true>;
+            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
                 multi_tensor_apply<4>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<4, accscalar_t, scalar_t>(),
+                    AdamFunctor<4, accscalar_t, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -293,17 +293,17 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         }
     } else {
         if (tl_sz == 5) {
-            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                 multi_tensor_apply<5>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<5, scalar_t, scalar_t>(),
+                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
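Note the asymmetry between the two halves of this file: the FLOAT_AND_HALF branches define accscalar_t because half gradients are accumulated at higher precision, while the DOUBLE_AND_FLOAT branches above and below pass the dispatched type for both AdamFunctor parameters, since float and double already serve as their own accumulation types. A minimal check of the trait involved (assuming ATen's AccumulateType header):

    #include <type_traits>
    #include <ATen/AccumulateType.h>

    // On CUDA, at::acc_type maps Half to float; this is what the
    // accscalar_t typedefs in the FLOAT_AND_HALF branches capture.
    static_assert(std::is_same<at::acc_type<at::Half, true>, float>::value,
                  "half gradients accumulate in float");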
@@ -311,15 +311,15 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         } else {
-            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                 multi_tensor_apply<4>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<4, scalar_t, scalar_t>(),
+                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -327,7 +327,7 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         }
     }
     THCudaCheck(cudaGetLastError());
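One more pattern that repeats through the diff: every dispatch site also moves from tensor.type() to tensor.scalar_type(). In the PyTorch of this era, Tensor::type() returned the deprecated DeprecatedTypeProperties handle that the lambda macros accepted, while the switch-based macros compare a plain at::ScalarType. A small illustration, as the API stood around this commit:

    at::Tensor g = at::zeros({16}, at::kHalf);
    at::ScalarType st = g.scalar_type();   // at::ScalarType::Half; what the
                                           // switch-based macros switch on
    // g.type() would return at::DeprecatedTypeProperties&, the older handle
    // that AT_DISPATCH_* accepted and that was being phased out.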