Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
42180bd9
Commit
42180bd9
authored
Mar 11, 2019
by
Michael Carilli
Browse files
Forward/backward compatibility around pytorch 3aeb78, to fix #191
parent
975ed322
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
125 additions
and
71 deletions
+125
-71
csrc/fused_adam_cuda.cpp
csrc/fused_adam_cuda.cpp
+0
-0
csrc/fused_adam_cuda_kernel.cu
csrc/fused_adam_cuda_kernel.cu
+11
-5
csrc/layer_norm_cuda.cpp
csrc/layer_norm_cuda.cpp
+0
-0
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+6
-2
csrc/multi_tensor_apply.cuh
csrc/multi_tensor_apply.cuh
+2
-2
csrc/multi_tensor_scale_kernel.cu
csrc/multi_tensor_scale_kernel.cu
+16
-6
csrc/type_shim.h
csrc/type_shim.h
+14
-0
csrc/welford.cu
csrc/welford.cu
+71
-51
docs/source/amp.rst
docs/source/amp.rst
+1
-1
setup.py
setup.py
+4
-4
No files found.
apex/optimizers/
csrc/fused_adam_cuda.cpp
→
csrc/fused_adam_cuda.cpp
View file @
42180bd9
File moved
apex/optimizers/
csrc/fused_adam_cuda_kernel.cu
→
csrc/fused_adam_cuda_kernel.cu
View file @
42180bd9
...
@@ -10,6 +10,8 @@
...
@@ -10,6 +10,8 @@
#include "ATen/AccumulateType.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include <THC/THCGeneral.h>
#include "type_shim.h"
typedef
enum
{
typedef
enum
{
ADAM_MODE_0
=
0
,
// eps under square root
ADAM_MODE_0
=
0
,
// eps under square root
ADAM_MODE_1
=
1
// eps outside square root
ADAM_MODE_1
=
1
// eps outside square root
...
@@ -29,8 +31,8 @@ __global__ void adam_cuda_kernel(
...
@@ -29,8 +31,8 @@ __global__ void adam_cuda_kernel(
const
float
step_size
,
const
float
step_size
,
const
size_t
tsize
,
const
size_t
tsize
,
adamMode_t
mode
,
adamMode_t
mode
,
const
float
decay
)
{
const
float
decay
)
{
//Assuming 2D grids and 2D blocks
//Assuming 2D grids and 2D blocks
const
int
blockId
=
gridDim
.
x
*
blockIdx
.
y
+
blockIdx
.
x
;
const
int
blockId
=
gridDim
.
x
*
blockIdx
.
y
+
blockIdx
.
x
;
const
int
threadsPerBlock
=
blockDim
.
x
*
blockDim
.
y
;
const
int
threadsPerBlock
=
blockDim
.
x
*
blockDim
.
y
;
...
@@ -67,7 +69,9 @@ void fused_adam_cuda(
...
@@ -67,7 +69,9 @@ void fused_adam_cuda(
int
step
,
int
step
,
int
mode
,
int
mode
,
int
bias_correction
,
int
bias_correction
,
float
decay
)
{
float
decay
)
{
// using namespace at;
//Get tensor size
//Get tensor size
int
tsize
=
p
.
numel
();
int
tsize
=
p
.
numel
();
...
@@ -91,7 +95,8 @@ void fused_adam_cuda(
...
@@ -91,7 +95,8 @@ void fused_adam_cuda(
//all other values should be fp32 for half gradients
//all other values should be fp32 for half gradients
AT_ASSERTM
(
p
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
"expected parameter to be of float type"
);
AT_ASSERTM
(
p
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
"expected parameter to be of float type"
);
//dispatch is done on the gradient type
//dispatch is done on the gradient type
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
g
.
type
(),
"adam_cuda_kernel"
,
([
&
]
{
using
namespace
at
;
// prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
g
.
type
()),
"adam_cuda_kernel"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
adam_cuda_kernel
<
accscalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data
<
accscalar_t
>
(),
p
.
data
<
accscalar_t
>
(),
...
@@ -109,7 +114,8 @@ void fused_adam_cuda(
...
@@ -109,7 +114,8 @@ void fused_adam_cuda(
decay
);
decay
);
}));
}));
}
else
{
}
else
{
AT_DISPATCH_FLOATING_TYPES
(
g
.
type
(),
"adam_cuda_kernel"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES
(
TypeShim
(
g
.
type
()),
"adam_cuda_kernel"
,
([
&
]
{
adam_cuda_kernel
<
scalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
adam_cuda_kernel
<
scalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data
<
scalar_t
>
(),
p
.
data
<
scalar_t
>
(),
NULL
,
//don't output p_copy for fp32, it's wasted write
NULL
,
//don't output p_copy for fp32, it's wasted write
...
...
apex/normalization/
csrc/layer_norm_cuda.cpp
→
csrc/layer_norm_cuda.cpp
View file @
42180bd9
File moved
apex/normalization/
csrc/layer_norm_cuda_kernel.cu
→
csrc/layer_norm_cuda_kernel.cu
View file @
42180bd9
...
@@ -6,6 +6,8 @@
...
@@ -6,6 +6,8 @@
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template
<
typename
U
>
__device__
template
<
typename
U
>
__device__
void
cuWelfordOnlineSum
(
void
cuWelfordOnlineSum
(
const
U
curr
,
const
U
curr
,
...
@@ -675,7 +677,8 @@ void cuda_layer_norm(
...
@@ -675,7 +677,8 @@ void cuda_layer_norm(
at
::
Tensor
*
beta
,
at
::
Tensor
*
beta
,
double
epsilon
)
double
epsilon
)
{
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
->
type
(),
"layer_norm_cuda_kernel"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
->
type
()),
"layer_norm_cuda_kernel"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
HostApplyLayerNorm
(
HostApplyLayerNorm
(
output
->
data
<
scalar_t
>
(),
output
->
data
<
scalar_t
>
(),
...
@@ -772,7 +775,8 @@ void cuda_layer_norm_gradient(
...
@@ -772,7 +775,8 @@ void cuda_layer_norm_gradient(
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_beta
)
at
::
Tensor
*
grad_beta
)
{
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
->
type
(),
"cuComputeGradInput"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
->
type
()),
"cuComputeGradInput"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
HostLayerNormGradient
(
HostLayerNormGradient
(
dout
->
data
<
scalar_t
>
(),
dout
->
data
<
scalar_t
>
(),
...
...
csrc/multi_tensor_apply.cuh
View file @
42180bd9
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
constexpr
int
depth_to_max_tensors
[
5
]
=
{
110
,
64
,
48
,
36
,
30
};
constexpr
int
depth_to_max_tensors
[
5
]
=
{
110
,
64
,
48
,
36
,
30
};
constexpr
int
depth_to_max_blocks
[
5
]
=
{
320
,
320
,
320
,
320
,
320
};
constexpr
int
depth_to_max_blocks
[
5
]
=
{
320
,
320
,
320
,
320
,
320
};
template
<
int
n
>
struct
TensorList
template
<
int
n
>
struct
TensorList
Metadata
{
{
void
*
addresses
[
n
][
depth_to_max_tensors
[
n
-
1
]];
void
*
addresses
[
n
][
depth_to_max_tensors
[
n
-
1
]];
int
sizes
[
depth_to_max_tensors
[
n
-
1
]];
int
sizes
[
depth_to_max_tensors
[
n
-
1
]];
...
@@ -62,7 +62,7 @@ void multi_tensor_apply(
...
@@ -62,7 +62,7 @@ void multi_tensor_apply(
int
ntensors
=
tensor_lists
[
0
].
size
();
int
ntensors
=
tensor_lists
[
0
].
size
();
TensorList
<
depth
>
tl
;
TensorList
Metadata
<
depth
>
tl
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
csrc/multi_tensor_scale_kernel.cu
View file @
42180bd9
...
@@ -2,9 +2,15 @@
...
@@ -2,9 +2,15 @@
#include <ATen/AccumulateType.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define BLOCK_SIZE 512
#define ILP 4
#define ILP 4
...
@@ -15,7 +21,7 @@ struct ScaleFunctor
...
@@ -15,7 +21,7 @@ struct ScaleFunctor
__device__
__forceinline__
void
operator
()(
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
int
chunk_size
,
volatile
int
*
noop_gmem
,
volatile
int
*
noop_gmem
,
TensorList
<
2
>&
tl
,
TensorList
Metadata
<
2
>&
tl
,
float
scale
)
float
scale
)
{
{
__shared__
int
noop_smem
;
__shared__
int
noop_smem
;
...
@@ -87,15 +93,17 @@ void multi_tensor_scale_cuda(
...
@@ -87,15 +93,17 @@ void multi_tensor_scale_cuda(
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
scale
)
float
scale
)
{
{
using
namespace
at
;
// The output (downscaled) type is always float.
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
// and what logic should be moved out of multi_tensor_apply.
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
tensor_lists
[
0
][
0
].
type
(),
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
tensor_lists
[
0
][
0
].
type
()),
"multi_tensor_scale_cuda"
,
"multi_tensor_scale_cuda"
,
[
&
]
[
&
]
{
{
// using accscalar_t = acc_type<scalar_t, true>;
// using accscalar_t = acc_type<scalar_t, true>;
switch
(
tensor_lists
[
1
][
0
].
type
().
scalar
T
ype
())
switch
(
tensor_lists
[
1
][
0
].
scalar
_t
ype
())
{
{
case
at
::
ScalarType
::
Half
:
case
at
::
ScalarType
::
Half
:
multi_tensor_apply
<
2
>
(
multi_tensor_apply
<
2
>
(
...
@@ -116,8 +124,10 @@ void multi_tensor_scale_cuda(
...
@@ -116,8 +124,10 @@ void multi_tensor_scale_cuda(
scale
);
scale
);
break
;
break
;
default:
default:
AT_ERROR
(
"multi_tensor_scale_cuda not implemented for output type = "
,
std
::
stringstream
ss
;
tensor_lists
[
1
][
0
].
type
().
toString
());
ss
<<
"multi_tensor_scale_cuda not implemented for output type = "
<<
tensor_lists
[
1
][
0
].
dtype
();
AT_ERROR
(
ss
.
str
().
c_str
());
}
}
});
});
...
...
csrc/type_shim.h
0 → 100644
View file @
42180bd9
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
struct
TypeShim
{
const
at
::
Type
&
payload
;
TypeShim
(
const
at
::
Type
&
type
)
:
payload
(
type
)
{}
// Enable trivial conversion to a const at::Type& for pre-3aeb78
operator
const
at
::
Type
&
(){
return
payload
;
};
// Enable dispatch switch statements to take *this directly for post-3aeb78
operator
at
::
ScalarType
(){
return
payload
.
scalarType
();
};
};
csrc/welford.cu
View file @
42180bd9
...
@@ -3,13 +3,13 @@
...
@@ -3,13 +3,13 @@
#include <ATen/AccumulateType.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <vector>
#include <vector>
#include "type_shim.h"
__device__
__forceinline__
int
lastpow2
(
int
n
)
__device__
__forceinline__
int
lastpow2
(
int
n
)
{
{
...
@@ -844,16 +844,19 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
...
@@ -844,16 +844,19 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"welford_mean_var_kernel"
,
([
&
]
{
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
namespace
at
;
welford_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"welford_mean_var_kernel"
,
([
&
]
{
input
.
data
<
scalar_t
>
(),
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
out_mean
.
data
<
accscalar_t
>
(),
welford_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out_var_biased
.
data
<
accscalar_t
>
(),
input
.
data
<
scalar_t
>
(),
batch_size
,
out_mean
.
data
<
accscalar_t
>
(),
feature_size
,
out_var_biased
.
data
<
accscalar_t
>
(),
space_size
);
batch_size
,
}));
feature_size
,
space_size
);
}));
}
return
{
out_mean
,
out_var_biased
};
return
{
out_mean
,
out_var_biased
};
}
}
...
@@ -881,7 +884,8 @@ at::Tensor batchnorm_forward_CUDA(
...
@@ -881,7 +884,8 @@ at::Tensor batchnorm_forward_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
>
(),
...
@@ -898,7 +902,8 @@ at::Tensor batchnorm_forward_CUDA(
...
@@ -898,7 +902,8 @@ at::Tensor batchnorm_forward_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
>
(),
...
@@ -950,7 +955,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
...
@@ -950,7 +955,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward_reduce"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
>
(),
...
@@ -970,7 +976,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
...
@@ -970,7 +976,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward_reduce"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
>
(),
...
@@ -1017,7 +1024,8 @@ at::Tensor batchnorm_backward_CUDA(
...
@@ -1017,7 +1024,8 @@ at::Tensor batchnorm_backward_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
>
(),
...
@@ -1036,7 +1044,8 @@ at::Tensor batchnorm_backward_CUDA(
...
@@ -1036,7 +1044,8 @@ at::Tensor batchnorm_backward_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
>
(),
...
@@ -1072,18 +1081,21 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
...
@@ -1072,18 +1081,21 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
mean_feature_nodes
.
type
(),
"welford_parallel_kernel"
,
([
&
]
{
{
welford_kernel_parallel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
using
namespace
at
;
mean_feature_nodes
.
data
<
scalar_t
>
(),
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
mean_feature_nodes
.
type
()),
"welford_parallel_kernel"
,
([
&
]
{
var_biased
.
data
<
scalar_t
>
(),
welford_kernel_parallel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out_mean
.
data
<
scalar_t
>
(),
mean_feature_nodes
.
data
<
scalar_t
>
(),
out_var
.
data
<
scalar_t
>
(),
var_biased
.
data
<
scalar_t
>
(),
inv_std
.
data
<
scalar_t
>
(),
out_mean
.
data
<
scalar_t
>
(),
world_size
,
out_var
.
data
<
scalar_t
>
(),
feature_size
,
inv_std
.
data
<
scalar_t
>
(),
eps
,
world_size
,
numel
);
feature_size
,
}));
eps
,
numel
);
}));
}
return
{
out_mean
,
out_var
,
inv_std
};
return
{
out_mean
,
out_var
,
inv_std
};
}
}
...
@@ -1111,21 +1123,23 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
...
@@ -1111,21 +1123,23 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"welford_mean_var_c_last"
,
([
&
]
{
using
namespace
at
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"welford_mean_var_c_last"
,
([
&
]
{
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
welford_kernel_c_last
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
<<<
grid
,
block
,
0
,
stream
>>>
(
welford_kernel_c_last
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
input
.
data
<
scalar_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
out_mean
.
data
<
accscalar_t
>
(),
input
.
data
<
scalar_t
>
(),
out_var_biased
.
data
<
accscalar_t
>
(),
out_mean
.
data
<
accscalar_t
>
(),
staging_data_ptr
,
out_var_biased
.
data
<
accscalar_t
>
(),
semaphores_ptr
,
staging_data_ptr
,
reduction_size
,
semaphores_ptr
,
stride
);
reduction_size
,
}));
stride
);
}));
}
return
{
out_mean
,
out_var_biased
};
return
{
out_mean
,
out_var_biased
};
}
}
...
@@ -1149,7 +1163,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
...
@@ -1149,7 +1163,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
...
@@ -1167,7 +1182,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
...
@@ -1167,7 +1182,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
...
@@ -1222,7 +1238,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
...
@@ -1222,7 +1238,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward_reduce"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
...
@@ -1246,7 +1263,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
...
@@ -1246,7 +1263,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_backward_reduce"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
...
@@ -1291,7 +1309,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
...
@@ -1291,7 +1309,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
if
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalarType
()
==
at
::
ScalarType
::
Float
)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
...
@@ -1311,7 +1330,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
...
@@ -1311,7 +1330,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
AT_CHECK
(
input
.
type
().
scalarType
()
==
weight
.
value
().
type
().
scalarType
(),
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
"input.type().scalarType() is not supported with weight.type().scalarType()"
);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
using
namespace
at
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
TypeShim
(
input
.
type
()),
"batchnorm_forward"
,
([
&
]
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
...
...
docs/source/amp.rst
View file @
42180bd9
...
@@ -26,7 +26,7 @@ override the defaults established by the ``opt_level``.
...
@@ -26,7 +26,7 @@ override the defaults established by the ``opt_level``.
Example::
Example::
# Declare model and optimizer as usual
# Declare model and optimizer as usual
, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
...
...
setup.py
View file @
42180bd9
...
@@ -55,8 +55,8 @@ if "--cuda_ext" in sys.argv:
...
@@ -55,8 +55,8 @@ if "--cuda_ext" in sys.argv:
'--use_fast_math'
]}))
'--use_fast_math'
]}))
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fused_adam_cuda'
,
CUDAExtension
(
name
=
'fused_adam_cuda'
,
sources
=
[
'
apex/optimizers/
csrc/fused_adam_cuda.cpp'
,
sources
=
[
'csrc/fused_adam_cuda.cpp'
,
'
apex/optimizers/
csrc/fused_adam_cuda_kernel.cu'
],
'csrc/fused_adam_cuda_kernel.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,],
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
'--use_fast_math'
]}))
'--use_fast_math'
]}))
...
@@ -66,8 +66,8 @@ if "--cuda_ext" in sys.argv:
...
@@ -66,8 +66,8 @@ if "--cuda_ext" in sys.argv:
'csrc/welford.cu'
]))
'csrc/welford.cu'
]))
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fused_layer_norm_cuda'
,
CUDAExtension
(
name
=
'fused_layer_norm_cuda'
,
sources
=
[
'
apex/normalization/
csrc/layer_norm_cuda.cpp'
,
sources
=
[
'csrc/layer_norm_cuda.cpp'
,
'
apex/normalization/
csrc/layer_norm_cuda_kernel.cu'
],
'csrc/layer_norm_cuda_kernel.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_ge_1_1
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_ge_1_1
,
'nvcc'
:[
'-maxrregcount=50'
,
'nvcc'
:[
'-maxrregcount=50'
,
'-O3'
,
'-O3'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment