Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
1e218749
Unverified
Commit
1e218749
authored
Feb 11, 2022
by
Masaki Kozuki
Committed by
GitHub
Feb 11, 2022
Browse files
cast for `-Wc++11-narrowing` (#1288)
parent
c8c00ef5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
csrc/mlp.cpp
csrc/mlp.cpp
+2
-2
No files found.
csrc/mlp.cpp
View file @
1e218749
...
@@ -62,7 +62,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
...
@@ -62,7 +62,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
// create output/workspace tensor
// create output/workspace tensor
auto
out
=
at
::
empty
({
batch_size
,
output_features
.
back
()},
inputs
[
0
].
type
());
auto
out
=
at
::
empty
({
batch_size
,
output_features
.
back
()},
inputs
[
0
].
type
());
auto
reserved_space
=
at
::
empty
({
reserved_size
},
inputs
[
0
].
type
());
auto
reserved_space
=
at
::
empty
({
static_cast
<
long
>
(
reserved_size
)
},
inputs
[
0
].
type
());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto
lt_workspace
=
at
::
empty
({
1
<<
22
},
inputs
[
0
].
type
());
auto
lt_workspace
=
at
::
empty
({
1
<<
22
},
inputs
[
0
].
type
());
...
@@ -135,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
...
@@ -135,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
get_mlp_bp_workspace_in_bytes
<
scalar_t
>
(
batch_size
,
num_layers
,
output_features
.
data
());
get_mlp_bp_workspace_in_bytes
<
scalar_t
>
(
batch_size
,
num_layers
,
output_features
.
data
());
// auto work_space = at::empty({work_size*4}, at::kByte);
// auto work_space = at::empty({work_size*4}, at::kByte);
auto
work_space
=
at
::
empty
({
work_size
/
sizeof
(
scalar_t
)},
inputs
[
0
].
type
());
auto
work_space
=
at
::
empty
({
static_cast
<
long
>
(
work_size
/
sizeof
(
scalar_t
)
)
},
inputs
[
0
].
type
());
auto
result
=
mlp_bp
<
scalar_t
>
(
auto
result
=
mlp_bp
<
scalar_t
>
(
inputs
[
0
].
data_ptr
<
scalar_t
>
(),
inputs
[
0
].
data_ptr
<
scalar_t
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment