Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
bc822902
Unverified
Commit
bc822902
authored
Aug 13, 2020
by
Jun Ru Anderson
Committed by
GitHub
Aug 13, 2020
Browse files
[refactor] remove type_shim.h (#33)
parent
571f5efa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
259 deletions
+19
-259
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
+19
-78
fairscale/clib/fused_adam_cuda/type_shim.h
fairscale/clib/fused_adam_cuda/type_shim.h
+0
-181
No files found.
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
View file @
bc822902
...
...
@@ -14,8 +14,6 @@
#define BLOCK_SIZE 512
#define ILP 4
#include "type_shim.h"
typedef
enum
{
ADAM_MODE_0
=
0
,
// eps under square root
ADAM_MODE_1
=
1
// eps outside square root
...
...
@@ -137,81 +135,24 @@ void fused_adam_cuda(
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
size_t
tl_sz
=
tensor_lists
.
size
();
AT_ASSERTM
(
tl_sz
==
4
||
tl_sz
==
5
,
"expected tensor lists of size 4 or 5"
);
if
(
tensor_lists
[
3
][
0
].
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
//alher values should be fp32 for half gradients
AT_ASSERTM
(
tensor_lists
[
0
][
0
].
scalar_type
()
==
at
::
ScalarType
::
Half
,
"expected parameter to be of float type"
);
//dich is done on the gradient type
if
(
tl_sz
==
5
)
{
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
3
][
0
].
scalar_type
(),
0
,
"adam_cuda_mt_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
multi_tensor_apply
<
5
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
AdamFunctor
<
5
,
accscalar_t
,
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
(
adamMode_t
)
mode
,
decay
);
);
}
else
{
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
3
][
0
].
scalar_type
(),
0
,
"adam_cuda_mt_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
AdamFunctor
<
4
,
accscalar_t
,
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
(
adamMode_t
)
mode
,
decay
);
);
}
}
else
{
if
(
tl_sz
==
5
)
{
DISPATCH_DOUBLE_AND_FLOAT
(
tensor_lists
[
3
][
0
].
scalar_type
(),
0
,
"adam_cuda_mt_kernel"
,
multi_tensor_apply
<
5
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
AdamFunctor
<
5
,
scalar_t_0
,
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
(
adamMode_t
)
mode
,
decay
);
);
}
else
{
DISPATCH_DOUBLE_AND_FLOAT
(
tensor_lists
[
3
][
0
].
scalar_type
(),
0
,
"adam_cuda_mt_kernel"
,
AT_ASSERTM
(
tl_sz
==
4
,
"expected tensor lists of size 4"
);
// check that the model and gradients are FP32
AT_ASSERTM
(
tensor_lists
[
0
][
0
].
scalar_type
()
==
at
::
ScalarType
::
Float
);
AT_ASSERTM
(
tensor_lists
[
3
][
0
].
scalar_type
()
==
at
::
ScalarType
::
Float
);
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
AdamFunctor
<
4
,
scalar_t_0
,
scalar_t_0
>
(),
AdamFunctor
<
4
,
float
,
float
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
(
adamMode_t
)
mode
,
decay
);
decay
);
}
}
THCudaCheck
(
cudaGetLastError
());
}
fairscale/clib/fused_adam_cuda/type_shim.h
deleted
100644 → 0
View file @
571f5efa
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template
<
typename
T
>
__device__
__forceinline__
T
reduce_block_into_lanes
(
T
*
x
,
T
val
,
int
lanes
=
1
,
bool
share_result
=
false
)
// lanes is intended to be <= 32.
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
int
blockSize
=
blockDim
.
x
*
blockDim
.
y
;
// blockSize is intended to be a multiple of 32.
if
(
blockSize
>=
64
)
{
x
[
tid
]
=
val
;
__syncthreads
();
}
#pragma unroll
for
(
int
i
=
(
blockSize
>>
1
);
i
>=
64
;
i
>>=
1
)
{
if
(
tid
<
i
)
x
[
tid
]
=
x
[
tid
]
+
x
[
tid
+
i
];
__syncthreads
();
}
T
final
;
if
(
tid
<
32
)
{
if
(
blockSize
>=
64
)
final
=
x
[
tid
]
+
x
[
tid
+
32
];
else
final
=
val
;
// __SYNCWARP();
#pragma unroll
for
(
int
i
=
16
;
i
>=
lanes
;
i
>>=
1
)
final
=
final
+
__shfl_down_sync
(
0xffffffff
,
final
,
i
);
}
if
(
share_result
)
{
if
(
tid
<
lanes
)
x
[
tid
]
=
final
;
// EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads
();
}
return
final
;
}
template
<
typename
T
>
__device__
__forceinline__
T
reduce_block_into_lanes_max_op
(
T
*
x
,
T
val
,
int
lanes
=
1
,
bool
share_result
=
false
)
// lanes is intended to be <= 32.
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
int
blockSize
=
blockDim
.
x
*
blockDim
.
y
;
// blockSize is intended to be a multiple of 32.
if
(
blockSize
>=
64
)
{
x
[
tid
]
=
val
;
__syncthreads
();
}
#pragma unroll
for
(
int
i
=
(
blockSize
>>
1
);
i
>=
64
;
i
>>=
1
)
{
if
(
tid
<
i
)
x
[
tid
]
=
fmaxf
(
fabsf
(
x
[
tid
]),
fabsf
(
x
[
tid
+
i
]));
__syncthreads
();
}
T
final
;
if
(
tid
<
32
)
{
if
(
blockSize
>=
64
)
final
=
fmaxf
(
fabsf
(
x
[
tid
]),
fabsf
(
x
[
tid
+
32
]));
else
final
=
val
;
// __SYNCWARP();
#pragma unroll
for
(
int
i
=
16
;
i
>=
lanes
;
i
>>=
1
)
final
=
fmaxf
(
fabsf
(
final
),
fabsf
(
__shfl_down_sync
(
0xffffffff
,
final
,
i
)));
}
if
(
share_result
)
{
if
(
tid
<
lanes
)
x
[
tid
]
=
final
;
// EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads
();
}
return
final
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment