OpenDAS / apex · Commit e57f5d0e

Simple cut of the kernel in place

Authored Apr 09, 2019 by Michael Carilli
Parent: 03100f46
Showing 3 changed files with 60 additions and 1 deletion (+60 -1)
csrc/amp_C_frontend.cpp  +7  -0
csrc/type_shim.h         +51 -0
setup.py                 +2  -1
csrc/amp_C_frontend.cpp

@@ -14,9 +14,16 @@ void multi_tensor_axpby_cuda(
   float b,
   int arg_to_check);
 
+at::Tensor multi_tensor_l2norm_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
   m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
         "out = a*x + b*y for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+        "Computes L2 norm for a list of contiguous tensors");
 }
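For orientation, a minimal sketch of how the newly exported entry point might be driven from C++ extension code. The helper name l2norm_of, the chunk size of 2048*32, and the int-typed one-element noop_flag are assumptions made for illustration (they follow conventions used elsewhere in apex's multi-tensor kernels) and do not appear in this commit.

  // Hypothetical usage sketch, not part of this commit.
  #include <torch/extension.h>
  #include <vector>

  // Declaration from csrc/amp_C_frontend.cpp above.
  at::Tensor multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists);

  // l2norm_of is an invented helper. noop_flag is a one-element int tensor
  // that an upstream kernel can set nonzero to request a no-op (e.g. after
  // an overflow during mixed-precision training).
  at::Tensor l2norm_of(const std::vector<at::Tensor>& tensors)
  {
    auto noop_flag = at::zeros({1}, tensors[0].options().dtype(at::kInt));
    const int chunk_size = 2048 * 32; // assumed; a common apex choice
    return multi_tensor_l2norm_cuda(chunk_size, noop_flag, {tensors});
  }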
csrc/type_shim.h

@@ -31,3 +31,54 @@ struct TypeShim
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
+
+template<typename T, typename ReduceOp>
+__device__ __forceinline__ T reduce_block_into_lanes
+  (T* x,
+   T val,
+   int lanes,
+   bool share_result)
+// lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y*blockDim.x;
+  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
+
+  if(blockSize >= 64)
+  {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+  #pragma unroll
+  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
+  {
+    if(tid < i)
+      x[tid] = x[tid] + x[tid+i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if(tid < 32)
+  {
+    if(blockSize >= 64)
+      final = x[tid] + x[tid+32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+    #pragma unroll
+    for(int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if(share_result)
+  {
+    if(tid < lanes)
+      x[tid] = final; // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
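To make the contract of reduce_block_into_lanes concrete, here is a hedged sketch of a caller (not from this commit): a block-wide sum kernel. Per the comments in the function itself, the shared buffer must provide one T per thread, blockDim.x*blockDim.y should be a multiple of 32, and lanes must be <= 32. Note that the ReduceOp template parameter is declared but never used in this cut, so it cannot be deduced from the arguments and has to be supplied explicitly.

  // Hypothetical caller; assumes the header above is csrc/type_shim.h.
  #include "type_shim.h"

  __global__ void block_sum_kernel(const float* in, float* out, int n)
  {
    // One shared slot per thread, as reduce_block_into_lanes requires.
    __shared__ float smem[256];

    int gid = blockIdx.x * blockDim.x + threadIdx.x;
    float v = (gid < n) ? in[gid] : 0.f;

    // lanes = 1: the whole block's sum lands in lane 0 of warp 0.
    // ReduceOp is unused at this commit, so any type satisfies it here.
    float total = reduce_block_into_lanes<float, float>(smem, v, 1, false);

    if(threadIdx.x == 0)
      out[blockIdx.x] = total;
  }

  // Launch with a block size that is a multiple of 32, e.g.:
  //   block_sum_kernel<<<num_blocks, 256>>>(d_in, d_out, n);

With share_result=false only the first warp holds the reduced value, which is why the kernel reads it from thread 0; passing share_result=true would publish it back to shared memory for all warps.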
setup.py

@@ -71,7 +71,8 @@ if "--cuda_ext" in sys.argv:
         CUDAExtension(name='amp_C',
                       sources=['csrc/amp_C_frontend.cpp',
                                'csrc/multi_tensor_scale_kernel.cu',
-                               'csrc/multi_tensor_axpby_kernel.cu'],
+                               'csrc/multi_tensor_axpby_kernel.cu',
+                               'csrc/multi_tensor_l2norm_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3'],
                                           'nvcc':['-lineinfo',
                                                   '-O3',
...
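Judging from the guard shown in the hunk header (if "--cuda_ext" in sys.argv:), the new multi_tensor_l2norm_kernel.cu is only compiled when the CUDA extensions are requested at install time, presumably something like the following (the exact invocation recommended by apex's README may differ):

  python setup.py install --cuda_ext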