Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
79ccfa43
Commit
79ccfa43
authored
Apr 05, 2022
by
encmps
Committed by
binmakeswell
Apr 06, 2022
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu code style (#667)
parent
e4bcff9b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
150 deletions
+112
-150
colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
+112
-150
No files found.
colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
View file @
79ccfa43
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
...
@@ -8,14 +9,13 @@
...
@@ -8,14 +9,13 @@
#include <assert.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
// Launch configuration for the fused multi-tensor Adam kernel:
// threads per block, and instruction-level-parallelism factor
// (elements processed per thread per iteration).
#define BLOCK_SIZE 512
#define ILP 4

// Selects how weight decay is applied inside the Adam update.
typedef enum {
  ADAM_MODE_0 = 0,  // L2 regularization mode
  ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW)
} adamMode_t;
...
@@ -23,21 +23,12 @@ typedef enum
...
@@ -23,21 +23,12 @@ typedef enum
// All intermediate Adam arithmetic is carried out in single precision,
// independent of the storage type (float/half) of the tensors involved.
using MATH_T = float;
template
<
typename
T_g
,
typename
T_p
>
template
<
typename
T_g
,
typename
T_p
>
struct
AdamFunctor
struct
AdamFunctor
{
{
__device__
__forceinline__
void
operator
()(
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
4
>
&
tl
,
volatile
int
*
noop_gmem
,
const
float
beta1
,
const
float
beta2
,
const
float
beta1_correction
,
TensorListMetadata
<
4
>
&
tl
,
const
float
beta2_correction
,
const
float
epsilon
,
const
float
lr
,
const
float
beta1
,
adamMode_t
mode
,
const
float
decay
)
{
const
float
beta2
,
const
float
beta1_correction
,
const
float
beta2_correction
,
const
float
epsilon
,
const
float
lr
,
adamMode_t
mode
,
const
float
decay
)
{
// I'd like this kernel to propagate infs/nans.
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// if(*noop_gmem == 1)
// return;
// return;
...
@@ -65,27 +56,21 @@ struct AdamFunctor
...
@@ -65,27 +56,21 @@ struct AdamFunctor
n
-=
chunk_idx
*
chunk_size
;
n
-=
chunk_idx
*
chunk_size
;
// see note in multi_tensor_scale_kernel.cu
// see note in multi_tensor_scale_kernel.cu
for
(
int
i_start
=
0
;
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
+=
blockDim
.
x
*
ILP
)
{
i_start
+=
blockDim
.
x
*
ILP
)
{
MATH_T
r_g
[
ILP
];
MATH_T
r_g
[
ILP
];
MATH_T
r_p
[
ILP
];
MATH_T
r_p
[
ILP
];
MATH_T
r_m
[
ILP
];
MATH_T
r_m
[
ILP
];
MATH_T
r_v
[
ILP
];
MATH_T
r_v
[
ILP
];
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
if
(
i
<
n
&&
i
<
chunk_size
)
{
{
r_g
[
ii
]
=
g
[
i
];
r_g
[
ii
]
=
g
[
i
];
r_p
[
ii
]
=
p
[
i
];
r_p
[
ii
]
=
p
[
i
];
r_m
[
ii
]
=
m
[
i
];
r_m
[
ii
]
=
m
[
i
];
r_v
[
ii
]
=
v
[
i
];
r_v
[
ii
]
=
v
[
i
];
}
}
else
{
else
{
r_g
[
ii
]
=
MATH_T
(
0
);
r_g
[
ii
]
=
MATH_T
(
0
);
r_p
[
ii
]
=
MATH_T
(
0
);
r_p
[
ii
]
=
MATH_T
(
0
);
r_m
[
ii
]
=
MATH_T
(
0
);
r_m
[
ii
]
=
MATH_T
(
0
);
...
@@ -93,10 +78,8 @@ struct AdamFunctor
...
@@ -93,10 +78,8 @@ struct AdamFunctor
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
{
if
(
mode
==
ADAM_MODE_0
)
{
// L2
if
(
mode
==
ADAM_MODE_0
)
{
// L2
r_g
[
ii
]
=
r_g
[
ii
]
+
(
decay
*
r_p
[
ii
]);
r_g
[
ii
]
=
r_g
[
ii
]
+
(
decay
*
r_p
[
ii
]);
r_m
[
ii
]
=
beta1
*
r_m
[
ii
]
+
(
1
-
beta1
)
*
r_g
[
ii
];
r_m
[
ii
]
=
beta1
*
r_m
[
ii
]
+
(
1
-
beta1
)
*
r_g
[
ii
];
r_v
[
ii
]
=
beta2
*
r_v
[
ii
]
+
(
1
-
beta2
)
*
r_g
[
ii
]
*
r_g
[
ii
];
r_v
[
ii
]
=
beta2
*
r_v
[
ii
]
+
(
1
-
beta2
)
*
r_g
[
ii
]
*
r_g
[
ii
];
...
@@ -105,9 +88,7 @@ struct AdamFunctor
...
@@ -105,9 +88,7 @@ struct AdamFunctor
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
MATH_T
update
=
next_m_unbiased
/
denom
;
MATH_T
update
=
next_m_unbiased
/
denom
;
r_p
[
ii
]
=
r_p
[
ii
]
-
(
lr
*
update
);
r_p
[
ii
]
=
r_p
[
ii
]
-
(
lr
*
update
);
}
}
else
{
// weight decay
else
{
// weight decay
r_m
[
ii
]
=
beta1
*
r_m
[
ii
]
+
(
1
-
beta1
)
*
r_g
[
ii
];
r_m
[
ii
]
=
beta1
*
r_m
[
ii
]
+
(
1
-
beta1
)
*
r_g
[
ii
];
r_v
[
ii
]
=
beta2
*
r_v
[
ii
]
+
(
1
-
beta2
)
*
r_g
[
ii
]
*
r_g
[
ii
];
r_v
[
ii
]
=
beta2
*
r_v
[
ii
]
+
(
1
-
beta2
)
*
r_g
[
ii
]
*
r_g
[
ii
];
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
...
@@ -118,11 +99,9 @@ struct AdamFunctor
...
@@ -118,11 +99,9 @@ struct AdamFunctor
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
if
(
i
<
n
&&
i
<
chunk_size
)
{
{
p
[
i
]
=
r_p
[
ii
];
p
[
i
]
=
r_p
[
ii
];
m
[
i
]
=
r_m
[
ii
];
m
[
i
]
=
r_m
[
ii
];
v
[
i
]
=
r_v
[
ii
];
v
[
i
]
=
r_v
[
ii
];
...
@@ -132,46 +111,29 @@ struct AdamFunctor
...
@@ -132,46 +111,29 @@ struct AdamFunctor
}
}
};
};
// Host-side entry point for the fused multi-tensor Adam optimizer step.
//
// tensor_lists holds 4 parallel lists: [gradients, params, exp_avg (m),
// exp_avg_sq (v)].  The grad dtype and param dtype are dispatched
// independently (float/half) via DISPATCH_FLOAT_AND_HALF_FOR_G_P, and the
// actual per-chunk work is performed by AdamFunctor through
// multi_tensor_apply<4>.
//
// chunk_size / noop_flag follow the apex multi_tensor_apply convention:
// tensors are processed in chunks of chunk_size elements, and noop_gmem
// can signal the kernel to skip work.
//
// mode is interpreted as adamMode_t: 0 = L2 regularization, 1 = AdamW
// (decoupled weight decay).  When bias_correction == 1, the standard
// Adam bias-correction terms 1 - beta^step are applied; otherwise the
// corrections are left at 1.0f (i.e. disabled).
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
                            const int bias_correction,
                            const float weight_decay) {
  using namespace at;

  // Handle bias correction mode: defaults of 1.0f make the correction a
  // no-op unless explicitly requested.
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Dispatch on grad dtype (lists[0]) and param dtype (lists[1]); the macro
  // defines g_scalar_t_0 / p_scalar_t_0 for the instantiation below.
  DISPATCH_FLOAT_AND_HALF_FOR_G_P(
      tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
      "adam",
      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
                            beta2, bias_correction1, bias_correction2, epsilon,
                            lr, (adamMode_t)mode, weight_decay);)

  // Surface any kernel-launch error from the apply above.
  AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment