Commit 2619f1cb authored May 07, 2020 by Thor Johnsen
Resolve merge conflict
parent 91a5a87e
Showing 2 changed files with 379 additions and 538 deletions (+379, -538)
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (+0, -5)
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (+379, -533)
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (view file @ 2619f1cb)
...
...
@@ -8,13 +8,10 @@ void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...
...
@@ -84,8 +81,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
    m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
    m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
    m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
    m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
    m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
    m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
}
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (view file @ 2619f1cb)
...
...
@@ -14,6 +14,17 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p){
    return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
    typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
    ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
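The two helpers above are the building blocks of the vectorized fast path used further down: is_aligned() checks that a pointer can be accessed in ILP-element chunks, and load_store() moves one such chunk with a single aligned transaction. A minimal sketch of the intended usage pattern follows (an editorial illustration under the same #defines, not part of this commit; copy_vectorized is a made-up name):

// Editorial sketch only: each thread copies ILP consecutive T elements per call
// to load_store(), falling back to scalar copies when alignment does not hold.
template <typename T>
__global__ void copy_vectorized(T* dst, T* src, int n)
{
    if (n % ILP == 0 && is_aligned(dst) && is_aligned(src)) {
        for (int i = threadIdx.x; i * ILP < n; i += blockDim.x)
            load_store(dst, src, i, i);   // one ILP-wide load and store per iteration
    } else {
        for (int i = threadIdx.x; i < n; i += blockDim.x)
            dst[i] = src[i];              // scalar fallback, mirroring the kernels below
    }
}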
#include "type_shim.h"
typedef enum{
...
...
@@ -21,6 +32,359 @@ typedef enum{
    ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;
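For reference (an editorial note, not part of the diff), the two modes only differ in where $\epsilon$ enters the denominator of the Adam update: ADAM_MODE_0 uses $d = \sqrt{v + \epsilon}$, ADAM_MODE_1 uses $d = \sqrt{v} + \epsilon$.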
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T* __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
        //Assuming 2D grids and 2D blocks
        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        const int threadsPerBlock = blockDim.x * blockDim.y;
        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
        const int i = (blockId * threadsPerBlock + threadIdInBlock);
        const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

        for (int j = i; j < tsize; j += totThreads) {
                T scaled_grad = g[j] / grad_scale;
                m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
                v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(v[j] + eps);
                else // Mode 1
                    denom = sqrtf(v[j]) + eps;
                float update = (m[j] / denom) + (decay * p[j]);
                p[j] = p[j] - (step_size * update);
                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
        }
}
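Written out, each element of the kernel's main loop performs (an editorial restatement in the kernel's own notation, with scaled gradient $\tilde g = g_j/\text{grad\_scale}$): $m_j \leftarrow b_1 m_j + (1-b_1)\,\tilde g$, $v_j \leftarrow b_2 v_j + (1-b_2)\,\tilde g^2$, $p_j \leftarrow p_j - \text{step\_size}\,(m_j/d + \text{decay}\cdot p_j)$, where $d$ is chosen by the mode as noted above; when a reduced-precision parameter copy is requested, p_copy[j] is refreshed from the updated p[j].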
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* noop_gmem,
        TensorListMetadata<DEPTH>& tl,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        adamMode_t mode,
        const float decay)
    {
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* p = (T*)tl.addresses[0][tensor_loc];
        p += chunk_idx * chunk_size;
        T* m = (T*)tl.addresses[1][tensor_loc];
        m += chunk_idx * chunk_size;
        T* v = (T*)tl.addresses[2][tensor_loc];
        v += chunk_idx * chunk_size;
        GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];
        g += chunk_idx * chunk_size;
        GRAD_T* p_copy = NULL;
        if (DEPTH == 5) {
            p_copy = (GRAD_T*)tl.addresses[4][tensor_loc];
            p_copy += chunk_idx * chunk_size;
        }

        n -= chunk_idx * chunk_size;

        T incoming_p[ILP];
        T incoming_m[ILP];
        T incoming_v[ILP];
        T incoming_g[ILP];

        // to make things simple, we put aligned case in a different code path
        if (n % ILP == 0 &&
            chunk_size % ILP == 0 &&
            is_aligned(p) &&
            is_aligned(m) &&
            is_aligned(v) &&
            is_aligned(g) &&
            is_aligned(p_copy))
        {
            for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
                // load
                GRAD_T tmp_g[ILP];
                load_store(incoming_p, p, 0, i_start);
                load_store(incoming_m, m, 0, i_start);
                load_store(incoming_v, v, 0, i_start);
                load_store(tmp_g, g, 0, i_start);
#pragma unroll
                for (int ii = 0; ii < ILP; ii++) {
                    incoming_g[ii] = static_cast<T>(tmp_g[ii]);
                    T scaled_grad = incoming_g[ii] / grad_scale;
                    incoming_m[ii] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;
                    incoming_v[ii] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;
                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(incoming_v[ii] + eps);
                    else // Mode 1
                        denom = sqrtf(incoming_v[ii]) + eps;
                    float update = (incoming_m[ii] / denom) + (decay * incoming_p[ii]);
                    incoming_p[ii] = incoming_p[ii] - (step_size * update);
                    if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
                }
                load_store(p, incoming_p, i_start, 0);
                load_store(m, incoming_m, i_start, 0);
                load_store(v, incoming_v, i_start, 0);
                if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
            }
        }
        else
        {
            for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
                for (int ii = 0; ii < ILP; ii++) {
                    incoming_p[ii] = 0;
                    incoming_m[ii] = 0;
                    incoming_v[ii] = 0;
                    incoming_g[ii] = 0;
                    int i = i_start + threadIdx.x + ii * blockDim.x;
                    if (i < n && i < chunk_size) {
                        incoming_p[ii] = p[i];
                        incoming_m[ii] = m[i];
                        incoming_v[ii] = v[i];
                        incoming_g[ii] = static_cast<T>(g[i]);
                    }
                }
                // note for clarification to future michael:
                // From a pure memory dependency perspective, there's likely no point unrolling
                // the write loop, since writes just fire off once their LDGs arrive.
                // Put another way, the STGs are dependent on the LDGs, but not on each other.
                // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
                for (int ii = 0; ii < ILP; ii++) {
                    int j = i_start + threadIdx.x + ii * blockDim.x;
                    if (j < n && j < chunk_size) {
                        T scaled_grad = incoming_g[ii] / grad_scale;
                        m[j] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;
                        v[j] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;
                        float denom;
                        if (mode == ADAM_MODE_0)
                            denom = sqrtf(v[j] + eps);
                        else // Mode 1
                            denom = sqrtf(v[j]) + eps;
                        float update = (m[j] / denom) + (decay * incoming_p[ii]);
                        p[j] = incoming_p[ii] - (step_size * update);
                        if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
                    }
                }
            }
        }
    }
};
void fused_adam_cuda(
        at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
        float lr, float beta1, float beta2, float eps, float grad_scale,
        int step, int mode, int bias_correction, float decay)
{
        // using namespace at;

        //Get tensor size
        int tsize = p.numel();
        //Determine #threads and #blocks
        const int threadsPerBlock = 512;
        const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
        //Constants
        float step_size = 0;
        if (bias_correction == 1) {
            const float bias_correction1 = 1 - std::pow(beta1, step);
            const float bias_correction2 = 1 - std::pow(beta2, step);
            step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
        }
        else {
            step_size = lr;
        }
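The two branches above fold Adam's bias correction into a single per-step scale (editorial restatement): with learning rate $\eta$ (lr) and step count $t$, $\text{step\_size} = \eta\,\sqrt{1-\beta_2^{\,t}}\,/\,(1-\beta_1^{\,t})$ when bias_correction == 1, and $\text{step\_size} = \eta$ otherwise.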
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        if (g.scalar_type() == at::ScalarType::Half) {
            //all other values should be fp32 for half gradients
            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
            //dispatch is done on the gradient type
            using namespace at; // prevents "toString is undefined" errors
            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                        p.DATA_PTR<accscalar_t>(),
                        p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
                        m.DATA_PTR<accscalar_t>(),
                        v.DATA_PTR<accscalar_t>(),
                        g.DATA_PTR<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
                );
        } else {
            using namespace at;
            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                        p.DATA_PTR<scalar_t_0>(),
                        NULL, //don't output p_copy for fp32, it's wasted write
                        m.DATA_PTR<scalar_t_0>(),
                        v.DATA_PTR<scalar_t_0>(),
                        g.DATA_PTR<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
                );
        }
        THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    }
    else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        if (tl_sz == 5) {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    } else {
        if (tl_sz == 5) {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    }
    THCudaCheck(cudaGetLastError());
}
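For orientation, a hypothetical host-side caller of this multi-tensor entry point might group the per-parameter tensors as follows. This is an editorial sketch, not part of the commit; the function name, tensor names, and the chunk size are placeholders, and the only assumption taken from the code above is the list order stated in the signature comment (p, m, v, g, and optionally p_copy).

// Editorial sketch only: how a caller might group per-parameter tensors for
// fused_adam_cuda_mt. All names below are placeholders, not apex API.
void run_fused_adam_step(std::vector<at::Tensor> params,
                         std::vector<at::Tensor> exp_avgs,
                         std::vector<at::Tensor> exp_avg_sqs,
                         std::vector<at::Tensor> grads,
                         at::Tensor overflow_buf,
                         float lr, float beta1, float beta2, float eps,
                         float grad_scale, int step, float decay)
{
    // Lists must be parallel and ordered as the comment in the signature says:
    // p, m, v, g (and optionally a fifth list holding fp16 parameter copies).
    std::vector<std::vector<at::Tensor>> tensor_lists = {params, exp_avgs, exp_avg_sqs, grads};
    fused_adam_cuda_mt(/*chunk_size=*/2048 * 32, overflow_buf, tensor_lists,
                       lr, beta1, beta2, eps, grad_scale, step,
                       /*mode=*/0, /*bias_correction=*/1, decay);
}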
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
...
...
@@ -202,44 +566,6 @@ __global__ void maybe_cast_kernel(
}
}
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T* __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
        //Assuming 2D grids and 2D blocks
        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
        const int threadsPerBlock = blockDim.x * blockDim.y;
        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
        const int i = (blockId * threadsPerBlock + threadIdInBlock);
        const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

        for (int j = i; j < tsize; j += totThreads) {
                T scaled_grad = g[j] / grad_scale;
                m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
                v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(v[j] + eps);
                else // Mode 1
                    denom = sqrtf(v[j]) + eps;
                float update = (m[j] / denom) + (decay * p[j]);
                p[j] = p[j] - (step_size * update);
                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
        }
}
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
        T* __restrict__ p,
...
...
@@ -402,266 +728,53 @@ __global__ void maybe_adam_undo_cuda_kernel(
}
}
}
}
__global__ void update_step_and_loss_scaler_kernel(
        volatile int* overflow_flag,
        double* __restrict__ step_and_loss_scaler_vec)
{
    // 0 : step
    // 1 : iter
    // 2 : loss_scale
    // 3 : last_overflow_iter
    // 4 : scale_factor
    // 5 : scale_window
    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 &&
        blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
    {
        double loss_scale = step_and_loss_scaler_vec[2];
        double scale_factor = step_and_loss_scaler_vec[4];
        int iter = static_cast<int>(step_and_loss_scaler_vec[1]);
        int last_overflow_iter = static_cast<int>(step_and_loss_scaler_vec[3]);
        if (*overflow_flag == 0) {
            // increase step
            step_and_loss_scaler_vec[0] += 1.0;
            // maybe increase loss scaler
            int scale_window = static_cast<int>(step_and_loss_scaler_vec[5]);
            if (((iter - last_overflow_iter) % scale_window) == 0) {
                step_and_loss_scaler_vec[2] = loss_scale * scale_factor;
            }
        } else {
            step_and_loss_scaler_vec[2] = loss_scale / scale_factor;
            step_and_loss_scaler_vec[3] = static_cast<double>(iter);
        }
        step_and_loss_scaler_vec[1] += 1.0;
    }
}
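In equation form (an editorial restatement of the single-thread branch above): with overflow flag $o$, iteration counter $i$, last overflow iteration $\ell$, scale window $w$ and scale factor $f$, the kernel applies
$$\text{loss\_scale} \leftarrow \begin{cases}\text{loss\_scale}\cdot f & o = 0 \ \text{and}\ (i-\ell)\bmod w = 0\\ \text{loss\_scale}/f & o \neq 0\\ \text{loss\_scale} & \text{otherwise,}\end{cases}$$
and it advances the step counter only when $o = 0$, while the iteration counter always advances.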
template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* overflow_flag,
        TensorListMetadata<DEPTH>& tl)
    {
        if (overflow_flag && *overflow_flag != 0) return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc];
        p_in += chunk_idx * chunk_size;
        TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc];
        p_out += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;
        int dim = chunk_size < n ? chunk_size : n;

        FROM_T pi[ILP];
        TO_T po[ILP];

        for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                pi[ii] = FROM_T(0);
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    pi[ii] = p_in[j];
                }
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                convert(pi[ii], po[ii]);
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    p_out[j] = po[ii];
                }
            }
        }
    }
};
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* overflow_flag,
        TensorListMetadata<DEPTH>& tl,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        adamMode_t mode,
        const float decay)
    {
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* p = (T*)tl.addresses[0][tensor_loc];
        p += chunk_idx * chunk_size;
        T* m = (T*)tl.addresses[1][tensor_loc];
        m += chunk_idx * chunk_size;
        T* v = (T*)tl.addresses[2][tensor_loc];
        v += chunk_idx * chunk_size;
        GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];
        g += chunk_idx * chunk_size;
        GRAD_T* p_copy = NULL;
        if (DEPTH == 5) {
            p_copy = (GRAD_T*)tl.addresses[4][tensor_loc];
            p_copy += chunk_idx * chunk_size;
        }

        n -= chunk_idx * chunk_size;
        int dim = chunk_size < n ? chunk_size : n;

        T mi[ILP];
        T vi[ILP];
        T pi[ILP];
        T gi[ILP];

        bool overflow = false;
        for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                mi[ii] = T(0);
                vi[ii] = T(0);
                pi[ii] = T(0);
                gi[ii] = GRAD_T(0);
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    pi[ii] = p[j];
                    mi[ii] = m[j];
                    vi[ii] = v[j];
                    gi[ii] = static_cast<T>(g[j]);
                }
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                T scaled_grad = gi[ii] / grad_scale;
                if (isfinite(scaled_grad)) {
                    mi[ii] = b1 * mi[ii] + (1 - b1) * scaled_grad;
                    vi[ii] = b2 * vi[ii] + (1 - b2) * scaled_grad * scaled_grad;
                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(vi[ii] + eps);
                    else // Mode 1
                        denom = sqrtf(vi[ii]) + eps;
                    float update = (mi[ii] / denom) + (decay * pi[ii]);
                    pi[ii] = pi[ii] - (step_size * update);
                } else {
                    overflow = true;
                }
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    m[j] = mi[ii];
                    v[j] = vi[ii];
                    p[j] = pi[ii];
                    if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
                }
            }
        }
        if (overflow) {
            *overflow_flag = 1;
        }
    }
};
}
template <int DEPTH, typename T, typename GRAD_T>
struct MaybeAdamUndoFunctor
template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* overflow_flag,
        TensorListMetadata<DEPTH>& tl,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        adamMode_t mode,
        const float decay)
        TensorListMetadata<DEPTH>& tl)
    {
        // Skip Adam undo when overflow flag is NOT set
        if (overflow_flag && *overflow_flag == 0) return;
        if (overflow_flag && *overflow_flag != 0) return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* p = (T*)tl.addresses[0][tensor_loc];
        p += chunk_idx * chunk_size;
        T* m = (T*)tl.addresses[1][tensor_loc];
        m += chunk_idx * chunk_size;
        T* v = (T*)tl.addresses[2][tensor_loc];
        v += chunk_idx * chunk_size;
        GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];
        g += chunk_idx * chunk_size;
        FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc];
        p_in += chunk_idx * chunk_size;
        TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc];
        p_out += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;
        int dim = chunk_size < n ? chunk_size : n;

        T mi[ILP];
        T vi[ILP];
        T pi[ILP];
        T gi[ILP];
        FROM_T pi[ILP];
        TO_T po[ILP];

        for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                mi[ii] = T(0);
                vi[ii] = T(0);
                pi[ii] = T(0);
                gi[ii] = GRAD_T(0);
                pi[ii] = FROM_T(0);
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    pi[ii] = p[j];
                    mi[ii] = m[j];
                    vi[ii] = v[j];
                    gi[ii] = static_cast<T>(g[j]);
                    pi[ii] = p_in[j];
                }
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                T scaled_grad = gi[ii] / grad_scale;
                if (isfinite(scaled_grad)) {
                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(vi[ii] + eps);
                    else // Mode 1
                        denom = sqrtf(vi[ii]) + eps;
                    pi[ii] = (pi[ii] + step_size * (mi[ii] / denom)) / (1.0f - step_size * decay);
                    mi[ii] = (mi[ii] - (1 - b1) * scaled_grad) / b1;
                    vi[ii] = (vi[ii] - (1 - b2) * scaled_grad * scaled_grad) / b2;
                    // Make sure round off errors don't create (small) negative value.
                    // This can happen if we have to revert the very first step.
                    vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
                }
                convert(pi[ii], po[ii]);
            }

#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int j = j_start + threadIdx.x + ii * blockDim.x;
                if (j < dim) {
                    m[j] = mi[ii];
                    v[j] = vi[ii];
                    p[j] = pi[ii];
                    p_out[j] = po[ii];
                }
            }
        }
...
...
@@ -696,86 +809,6 @@ void fused_strided_check_finite(
    THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda(
        at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
        float lr, float beta1, float beta2, float eps, float grad_scale,
        int step, int mode, int bias_correction, float decay)
{
        // using namespace at;

        //Get tensor size
        int tsize = p.numel();
        //Determine #threads and #blocks
        const int threadsPerBlock = 512;
        const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
        //Constants
        float step_size = 0;
        if (bias_correction == 1) {
            const float bias_correction1 = 1 - std::pow(beta1, step);
            const float bias_correction2 = 1 - std::pow(beta2, step);
            step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
        }
        else {
            step_size = lr;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        if (g.scalar_type() == at::ScalarType::Half) {
            //all other values should be fp32 for half gradients
            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
            //dispatch is done on the gradient type
            using namespace at; // prevents "toString is undefined" errors
            DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                        p.DATA_PTR<accscalar_t>(),
                        p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
                        m.DATA_PTR<accscalar_t>(),
                        v.DATA_PTR<accscalar_t>(),
                        g.DATA_PTR<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
                );
        } else {
            using namespace at;
            DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                        p.DATA_PTR<scalar_t_0>(),
                        NULL, //don't output p_copy for fp32, it's wasted write
                        m.DATA_PTR<scalar_t_0>(),
                        v.DATA_PTR<scalar_t_0>(),
                        g.DATA_PTR<scalar_t_0>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t) mode,
                        decay);
                );
        }
        THCudaCheck(cudaGetLastError());
}
void fused_reversible_adam_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
...
...
@@ -923,18 +956,6 @@ void maybe_cast_cuda_mt(
    THCudaCheck(cudaGetLastError());
}
void update_step_and_loss_scaler_cuda(
        at::Tensor & overflow_flag,
        at::Tensor & step_and_loss_scaler)
{
        AT_ASSERTM(step_and_loss_scaler.numel() == 6, "step_and_loss_scaler must have 6 elements");
        AT_ASSERTM(step_and_loss_scaler.scalar_type() == at::ScalarType::Double, "expected step_and_loss_scaler to be a double tensor");
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        update_step_and_loss_scaler_kernel<<<1, 1, 0, stream>>>(
                overflow_flag.DATA_PTR<int>(),
                step_and_loss_scaler.DATA_PTR<double>());
}
void fused_maybe_adam_undo_cuda(
        at::Tensor & overflow_flag,
        at::Tensor & p,
...
...
@@ -1013,178 +1034,3 @@ void fused_maybe_adam_undo_cuda(
    THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor overflow_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    }
    else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        if (tl_sz == 5) {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    overflow_flag,
                    tensor_lists,
                    AdamFunctor<5, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    overflow_flag,
                    tensor_lists,
                    AdamFunctor<4, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    } else {
        if (tl_sz == 5) {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    overflow_flag,
                    tensor_lists,
                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    overflow_flag,
                    tensor_lists,
                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    }
    THCudaCheck(cudaGetLastError());
}
void fused_maybe_adam_undo_cuda_mt(
    int chunk_size,
    at::Tensor overflow_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    }
    else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4, "expected tensor list of size 4");

    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
            using accscalar_t = at::acc_type<scalar_t_0, true>;
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                overflow_flag,
                tensor_lists,
                MaybeAdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                (adamMode_t) mode,
                decay);
            );
    } else {
        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                overflow_flag,
                tensor_lists,
                MaybeAdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                (adamMode_t) mode,
                decay);
            );
    }
    THCudaCheck(cudaGetLastError());
}