OpenDAS / apex
Commit 8cc99c29, authored Mar 11, 2020 by Thor Johnsen

    First commit

Parent: 5633f6db
Showing 2 changed files with 536 additions and 89 deletions
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp         +26  -0
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu   +510 -89
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (view file @ 8cc99c29)
...
@@ -2,8 +2,10 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
...
@@ -11,6 +13,15 @@ void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::v
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface

void strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first) {
    CHECK_INPUT(p_copy);
    fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
}

void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
...
@@ -25,8 +36,23 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
    CHECK_INPUT(p);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
    m.def("adam", &adam, "Adam optimized CUDA implementation.");
    m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
    m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
    m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
}
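For orientation only (not part of the commit): the block above is the standard PyTorch C++ extension binding pattern, so each m.def name becomes a callable attribute of the built extension module. A minimal self-contained sketch of that pattern, using a hypothetical demo_scale function rather than the apex kernels:

#include <torch/extension.h>

// Hypothetical demo function; mirrors the argument-checking style used above.
void demo_scale(at::Tensor & x, float alpha) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
    x.mul_(alpha);   // in-place scale on the GPU
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("demo_scale", &demo_scale, "In-place scale (demo).");
}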
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (view file @ 8cc99c29)
...
@@ -21,6 +21,36 @@ typedef enum{
    ADAM_MODE_1   =1  // eps outside square root
} adamMode_t;

template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
    volatile int* noop_gmem,
    GRAD_T* __restrict__ p_copy,
    const size_t tsize,
    int stride,
    int clear_overflow_first)
{
    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;

    if (clear_overflow_first) {
        if (i == 0) {
            *noop_gmem = 0;
        }
        __syncthreads();
    }

    for (int j = i; j < tsize; j += totThreads) {
        GRAD_T pi = p_copy[j];
        if (!isfinite(pi)) {
            *noop_gmem = 1;
        }
    }
}
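As a reading aid (not part of the commit): the two adamMode_t values only change where eps enters the denominator of the update that adam_cuda_kernel applies. With the scaled gradient \hat{g} = g / \mathrm{grad\_scale}:

m_t = \beta_1 m_{t-1} + (1-\beta_1)\,\hat{g}
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\hat{g}^2
\mathrm{denom} = \sqrt{v_t + \epsilon} \quad (\text{ADAM\_MODE\_0, eps inside the square root})
\mathrm{denom} = \sqrt{v_t} + \epsilon \quad (\text{ADAM\_MODE\_1, eps outside the square root})
p_t = p_{t-1} - \mathrm{step\_size}\left(\frac{m_t}{\mathrm{denom}} + \mathrm{decay}\cdot p_{t-1}\right)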
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
    T* __restrict__ p,
...
@@ -37,26 +67,148 @@ __global__ void adam_cuda_kernel(
    adamMode_t mode,
    const float decay)
{
    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;

-   for (int j = i; j < tsize; j += totThreads) {
-       T scaled_grad = g[j]/grad_scale;
-       m[j] = b1*m[j] + (1-b1)*scaled_grad;
-       v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
-       float denom;
-       if (mode == ADAM_MODE_0)
-           denom = sqrtf(v[j] + eps);
-       else // Mode 1
-           denom = sqrtf(v[j]) + eps;
-       float update = (m[j]/denom) + (decay*p[j]);
-       p[j] = p[j] - (step_size*update);
-       if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
-   }

    T mi[ILP];
    T vi[ILP];
    T pi[ILP];
    T gi[ILP];

    bool overflow = false;
    for (int j_start = 0; j_start < tsize; j_start += totThreads*ILP) {

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            mi[ii] = T(0);
            vi[ii] = T(0);
            pi[ii] = T(0);
            gi[ii] = GRAD_T(0);

            int j = j_start + i + totThreads*ii;
            if (j < tsize) {
                pi[ii] = p[j];
                mi[ii] = m[j];
                vi[ii] = v[j];
                gi[ii] = static_cast<T>(g[j]);
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + i*ILP;
            T scaled_grad = gi[ii]/grad_scale;
            if (isfinite(scaled_grad)) {
                mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
                vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vi[ii] + eps);
                else // Mode 1
                    denom = sqrtf(vi[ii]) + eps;
                float update = (mi[ii]/denom) + (decay*pi[ii]);
                pi[ii] = pi[ii] - (step_size*update);
            } else {
                overflow = true;
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + i + totThreads*ii;
            if (j < tsize) {
                m[j] = mi[ii];
                v[j] = vi[ii];
                p[j] = pi[ii];
                if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
            }
        }
    }

    if (p_copy != NULL) {
        __syncthreads();
        if (overflow) {
            p_copy[0] = INFINITY;
        }
    }
}
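As a standalone reading aid (not part of the commit): the rewritten kernel above is a grid-stride loop with ILP-wide register blocking — load ILP elements into registers, compute on registers only, then store. A minimal sketch of just that pattern, under assumed names (scale_kernel, ILP = 4, the launch configuration) and with a simple element-wise scale in place of the Adam math:

#include <cstdio>
#include <cuda_runtime.h>

#define ILP 4   // register-blocking width (assumed value for this sketch)

// Grid-stride + ILP pattern: load ILP elements to registers, compute, store.
__global__ void scale_kernel(float* __restrict__ x, float alpha, size_t n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int totThreads = gridDim.x * blockDim.x;

    float r[ILP];
    for (size_t j_start = 0; j_start < n; j_start += (size_t)totThreads * ILP) {
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {            // load phase
            size_t j = j_start + i + (size_t)totThreads * ii;
            r[ii] = (j < n) ? x[j] : 0.f;
        }
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {            // compute phase (registers only)
            r[ii] *= alpha;
        }
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {            // store phase
            size_t j = j_start + i + (size_t)totThreads * ii;
            if (j < n) x[j] = r[ii];
        }
    }
}

int main()
{
    const size_t n = 1 << 20;
    float* d_x;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMemset(d_x, 0, n * sizeof(float));
    scale_kernel<<<256, 512>>>(d_x, 2.0f, n);         // grid-stride loop covers any n
    cudaDeviceSynchronize();
    printf("%s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(d_x);
    return 0;
}

Splitting load/compute/store this way lets the compiler issue the ILP loads back to back before any of them is consumed, which is the same motivation as the mi/vi/pi/gi staging arrays above.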
template <typename T, typename GRAD_T>
__global__ void adam_undo_cuda_kernel(
    T* __restrict__ p,
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T * __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay)
{
    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;

    T mi[ILP];
    T vi[ILP];
    T pi[ILP];
    T gi[ILP];

    for (int j_start = 0; j_start < tsize; j_start += totThreads*ILP) {

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            mi[ii] = T(0);
            vi[ii] = T(0);
            pi[ii] = T(0);
            gi[ii] = GRAD_T(0);

            int j = j_start + i*ILP;
            if (j < tsize) {
                pi[ii] = p[j];
                mi[ii] = m[j];
                vi[ii] = v[j];
                gi[ii] = static_cast<T>(g[j]);
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + i*ILP;
            T scaled_grad = gi[ii]/grad_scale;
            if (isfinite(scaled_grad)) {
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vi[ii] + eps);
                else // Mode 1
                    denom = sqrtf(vi[ii]) + eps;
                pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
                mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
                vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
                // Make sure round off errors don't create (small) negative value.
                // This can happen if we have to revert the very first step.
                vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + i*ILP;
            if (j < tsize) {
                m[j] = mi[ii];
                v[j] = vi[ii];
                p[j] = pi[ii];
            }
        }
    }
}
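Side note (not in the diff): the undo arithmetic above is just the forward update solved for the previous state, reusing the same scaled gradient \hat{g} and writing s for step_size:

\text{forward:}\quad p_t = (1 - s\cdot\mathrm{decay})\,p_{t-1} - s\,\frac{m_t}{\mathrm{denom}(v_t)}
\text{undo:}\quad p_{t-1} = \frac{p_t + s\,m_t/\mathrm{denom}(v_t)}{1 - s\cdot\mathrm{decay}},\qquad
m_{t-1} = \frac{m_t - (1-\beta_1)\,\hat{g}}{\beta_1},\qquad
v_{t-1} = \frac{v_t - (1-\beta_2)\,\hat{g}^2}{\beta_2}

The final clamp of vi[ii] to zero guards against round-off when reverting the very first step, where the true previous v is exactly zero.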
template <int DEPTH, typename T, typename GRAD_T>
...
@@ -93,59 +245,181 @@ struct AdamFunctor
    }

    n -= chunk_idx*chunk_size;

-   T incoming_p[ILP];
-   T incoming_m[ILP];
-   T incoming_v[ILP];
-   T incoming_g[ILP];
-
-   for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) {
-
-       #pragma unroll
-       for (int ii = 0; ii < ILP; ii++) {
-           incoming_p[ii] = 0;
-           incoming_m[ii] = 0;
-           incoming_v[ii] = 0;
-           incoming_g[ii] = 0;
-
-           int i = i_start + threadIdx.x + ii*blockDim.x;
-           if (i < n && i < chunk_size) {
-               incoming_p[ii] = p[i];
-               incoming_m[ii] = m[i];
-               incoming_v[ii] = v[i];
-               incoming_g[ii] = static_cast<T>(g[i]);
-           }
-       }
-
-       // note for clarification to future michael:
-       // From a pure memory dependency perspective, there's likely no point unrolling
-       // the write loop, since writes just fire off once their LDGs arrive.
-       // Put another way, the STGs are dependent on the LDGs, but not on each other.
-       // There is still compute ILP benefit from unrolling the loop though.
-       #pragma unroll
-       for (int ii = 0; ii < ILP; ii++) {
-           int j = i_start + threadIdx.x + ii*blockDim.x;
-           if (j < n && j < chunk_size) {
-               T scaled_grad = incoming_g[ii]/grad_scale;
-               m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
-               v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
-               float denom;
-               if (mode == ADAM_MODE_0)
-                   denom = sqrtf(v[j] + eps);
-               else // Mode 1
-                   denom = sqrtf(v[j]) + eps;
-               float update = (m[j]/denom) + (decay*incoming_p[ii]);
-               p[j] = incoming_p[ii] - (step_size*update);
-               if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];
-           }
-       }
-   }

    int dim = chunk_size < n ? chunk_size : n;

    T mi[ILP];
    T vi[ILP];
    T pi[ILP];
    T gi[ILP];

    bool overflow = false;
    for (int j_start = 0; j_start < dim; j_start += blockDim.x*ILP) {

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            mi[ii] = T(0);
            vi[ii] = T(0);
            pi[ii] = T(0);
            gi[ii] = GRAD_T(0);

            int j = j_start + threadIdx.x + ii*blockDim.x;
            if (j < tsize) {
                pi[ii] = p[j];
                mi[ii] = m[j];
                vi[ii] = v[j];
                gi[ii] = static_cast<T>(g[j]);
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + threadIdx.x + ii*blockDim.x;
            T scaled_grad = gi[ii]/grad_scale;
            if (isfinite(scaled_grad)) {
                mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
                vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vi[ii] + eps);
                else // Mode 1
                    denom = sqrtf(vi[ii]) + eps;
                float update = (mi[ii]/denom) + (decay*pi[ii]);
                pi[ii] = pi[ii] - (step_size*update);
            } else {
                overflow = true;
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + threadIdx.x + ii*blockDim.x;
            if (j < tsize) {
                m[j] = mi[ii];
                v[j] = vi[ii];
                p[j] = pi[ii];
                if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
            }
        }
    }

    if (overflow) {
        *noop_gmem = 1;
    }
  }
};

template <int DEPTH, typename T, typename GRAD_T>
struct AdamUndoFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    adamMode_t mode,
    const float decay)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T* p = (T *)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T *)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T *)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    int dim = chunk_size < n ? chunk_size : n;

    T mi[ILP];
    T vi[ILP];
    T pi[ILP];
    T gi[ILP];

    for (int j_start = 0; j_start < dim; j_start += blockDim.x*ILP) {

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            mi[ii] = T(0);
            vi[ii] = T(0);
            pi[ii] = T(0);
            gi[ii] = GRAD_T(0);

            int j = j_start + threadIdx.x + ii*blockDim.x;
            if (j < tsize) {
                pi[ii] = p[j];
                mi[ii] = m[j];
                vi[ii] = v[j];
                gi[ii] = static_cast<T>(g[j]);
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + threadIdx.x + ii*blockDim.x;
            T scaled_grad = gi[ii]/grad_scale;
            if (isfinite(scaled_grad)) {
                float denom;
                if (mode == ADAM_MODE_0)
                    denom = sqrtf(vi[ii] + eps);
                else // Mode 1
                    denom = sqrtf(vi[ii]) + eps;
                pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
                mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
                vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
                // Make sure round off errors don't create (small) negative value.
                // This can happen if we have to revert the very first step.
                vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
            }
        }

        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
            int j = j_start + threadIdx.x + ii*blockDim.x;
            if (j < tsize) {
                m[j] = mi[ii];
                v[j] = vi[ii];
                p[j] = pi[ii];
            }
        }
    }
  }
};

void fused_strided_check_finite(
    at::Tensor & noop,
    at::Tensor & p_copy,
    int stride,
    int clear_overflow_first)
{
    //Get tensor size
    int tsize = p_copy.numel();
    int niter = (tsize + stride - 1) / stride;

    //Determine #threads and #blocks
    const int threadsPerBlock = 512;
    const dim3 blocks((niter+threadsPerBlock-1)/threadsPerBlock);
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    using namespace at; // prevents "toString is undefined" errors
    DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
        strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
            noop.DATA_PTR<int>(),
            p_copy.DATA_PTR<scalar_t_0>(),
            tsize,
            stride,
            clear_overflow_first);
        );
    THCudaCheck(cudaGetLastError());
}
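A quick worked example of the launch geometry above (illustrative numbers, not from the commit): with threadsPerBlock = 512, a p_copy of 1,000,000 elements checked at stride = 4 gives

\mathrm{niter} = \left\lceil \tfrac{1{,}000{,}000}{4} \right\rceil = 250{,}000, \qquad
\mathrm{blocks} = \left\lceil \tfrac{250{,}000}{512} \right\rceil = 489,

so 489 blocks of 512 threads are launched; each thread starts at (blockId * threadsPerBlock + threadIdInBlock) * stride and steps by the total thread count times the stride until it passes tsize.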
void fused_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,
...
@@ -227,6 +501,84 @@ void fused_adam_cuda(
}

void fused_adam_undo_cuda(
    at::Tensor & p,
    at::Tensor & m,
    at::Tensor & v,
    at::Tensor & g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
//    using namespace at;

    //Get tensor size
    int tsize = p.numel();
    //Determine #threads and #blocks
    const int threadsPerBlock = 512;
    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    }
    else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        using namespace at; // prevents "toString is undefined" errors
        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
            using accscalar_t = at::acc_type<scalar_t_0, true>;
            adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                p.DATA_PTR<accscalar_t>(),
                m.DATA_PTR<accscalar_t>(),
                v.DATA_PTR<accscalar_t>(),
                g.DATA_PTR<scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
            );
    } else {
        using namespace at;
        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
            adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                p.DATA_PTR<scalar_t_0>(),
                m.DATA_PTR<scalar_t_0>(),
                v.DATA_PTR<scalar_t_0>(),
                g.DATA_PTR<scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
            );
    }
    THCudaCheck(cudaGetLastError());
}
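For reference (not part of the diff), the bias_correction branch above folds Adam's bias correction into the step size, so the kernels can work with the raw m and v:

\mathrm{step\_size} = \mathrm{lr}\cdot\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}},\qquad
\hat{m}_t = \frac{m_t}{1-\beta_1^{\,t}},\qquad
\hat{v}_t = \frac{v_t}{1-\beta_2^{\,t}},

so that step_size * m_t / sqrt(v_t) equals lr * \hat{m}_t / \sqrt{\hat{v}_t}; with eps > 0 the two forms differ only in how eps is scaled.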
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
...
@@ -262,48 +614,118 @@ void fused_adam_cuda_mt(
        //dispatch is done on the gradient type
        if (tl_sz == 5) {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                using accscalar_t = at::acc_type<scalar_t_0, true>;
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, accscalar_t, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    } else {
        if (tl_sz == 5) {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        } else {
            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
                );
        }
    }
    THCudaCheck(cudaGetLastError());
}
void fused_adam_undo_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    }
    else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4, "expected tensor list of size 4");

    if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
            using accscalar_t = at::acc_type<scalar_t_0, true>;
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                noop_flag,
                tensor_lists,
-               AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                AdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
                beta1,
                beta2,
                eps,
...
@@ -311,15 +733,15 @@ void fused_adam_cuda_mt(
                step_size,
                (adamMode_t) mode,
                decay);
            );
    } else {
-       DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                noop_flag,
                tensor_lists,
-               AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                AdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
                beta1,
                beta2,
                eps,
...
@@ -327,8 +749,7 @@ void fused_adam_cuda_mt(
                step_size,
                (adamMode_t) mode,
                decay);
            );
    }
    THCudaCheck(cudaGetLastError());
}