Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a2799893
Commit
a2799893
authored
Mar 08, 2019
by
Simon Layton
Browse files
Handle fp16 weights case without forcing fp16 math
Incorrect types used in a few places
parent
75c8a97a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
csrc/multi_tensor_sgd_kernel.cu
csrc/multi_tensor_sgd_kernel.cu
+11
-11
No files found.
csrc/multi_tensor_sgd_kernel.cu
View file @
a2799893
...
...
@@ -24,7 +24,7 @@
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
**/
template
<
int
N
,
typename
T_grad
,
typename
T
>
template
<
int
N
,
typename
T_grad
,
typename
T
_weight
>
struct
SGDFunctor
{
__device__
__forceinline__
void
operator
()(
...
...
@@ -53,24 +53,24 @@ struct SGDFunctor
T_grad
*
grad_in
=
(
T_grad
*
)
tl
.
addresses
[
0
][
tensor_loc
];
grad_in
+=
chunk_idx
*
chunk_size
;
T
*
weight_in
=
(
T
*
)
tl
.
addresses
[
1
][
tensor_loc
];
T
_weight
*
weight_in
=
(
T
_weight
*
)
tl
.
addresses
[
1
][
tensor_loc
];
weight_in
+=
chunk_idx
*
chunk_size
;
T
*
mom_in
=
(
T
*
)
tl
.
addresses
[
2
][
tensor_loc
];
T
_weight
*
mom_in
=
(
T
_weight
*
)
tl
.
addresses
[
2
][
tensor_loc
];
mom_in
+=
chunk_idx
*
chunk_size
;
h
alf
*
model_weights_out
=
nullptr
;
at
::
H
alf
*
model_weights_out
=
nullptr
;
if
(
N
==
4
)
{
model_weights_out
=
(
h
alf
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
=
(
at
::
H
alf
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
+=
chunk_idx
*
chunk_size
;
}
n
-=
chunk_idx
*
chunk_size
;
// Non-divergent exit condition for the __syncthreads
T
incoming_grads
[
ILP
];
T
incoming_weights
[
ILP
];
T
incoming_moms
[
ILP
];
float
incoming_grads
[
ILP
];
float
incoming_weights
[
ILP
];
float
incoming_moms
[
ILP
];
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
+=
blockDim
.
x
*
ILP
)
...
...
@@ -83,9 +83,9 @@ struct SGDFunctor
incoming_moms
[
ii
]
=
0
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
incoming_grads
[
ii
]
=
static_cast
<
T
>
(
grad_in
[
i
]);
incoming_weights
[
ii
]
=
static_cast
<
T
>
(
weight_in
[
i
]);
incoming_moms
[
ii
]
=
static_cast
<
T
>
(
mom_in
[
i
]);
incoming_grads
[
ii
]
=
static_cast
<
float
>
(
grad_in
[
i
]);
incoming_weights
[
ii
]
=
static_cast
<
float
>
(
weight_in
[
i
]);
incoming_moms
[
ii
]
=
static_cast
<
float
>
(
mom_in
[
i
]);
}
// note for clarification to future michael:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment