Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
683b6e0e
Commit
683b6e0e
authored
Apr 10, 2019
by
Michael Carilli
Browse files
Quick kernel to clean up l2norm
parent
1a48b26b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
2 deletions
+19
-2
csrc/multi_tensor_l2norm_kernel.cu
csrc/multi_tensor_l2norm_kernel.cu
+19
-2
No files found.
csrc/multi_tensor_l2norm_kernel.cu
View file @
683b6e0e
...
...
@@ -56,6 +56,20 @@ struct L2NormFunctor
}
};
__global__
void
cleanup
(
float
*
x
,
float
*
ret
)
{
__shared__
float
vals
[
512
];
float
val
=
0
;
if
(
threadIdx
.
x
<
320
)
val
=
x
[
threadIdx
.
x
];
float
final
=
reduce_block_into_lanes
(
vals
,
val
);
if
(
threadIdx
.
x
==
0
)
*
ret
=
sqrt
(
final
);
}
at
::
Tensor
multi_tensor_l2norm_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
...
...
@@ -76,8 +90,11 @@ at::Tensor multi_tensor_l2norm_cuda(
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves
two
more small kernel launches, but will be negligible end to end.
// This involves
one
more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
return
output
.
sum
().
sqrt
();
auto
ret
=
at
::
empty
({
1
},
output
.
options
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
cleanup
<<<
1
,
512
,
0
,
stream
>>>
(
output
.
data
<
float
>
(),
ret
.
data
<
float
>
());
return
ret
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment