Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f28c0213
Commit
f28c0213
authored
May 16, 2022
by
binmakeswell
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style (#978)
parent
18542b47
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
169 additions
and
241 deletions
+169
-241
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
...ssalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
+169
-241
No files found.
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
View file @
f28c0213
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include "compat.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define BLOCK_SIZE 512
#define ILP 4
#define ILP 4
...
@@ -28,24 +29,13 @@
...
@@ -28,24 +29,13 @@
* wd_after_momentum : apply weight decay _after_ momentum instead of before
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
**/
template
<
int
N
,
typename
T_grad
,
typename
T_weight
>
template
<
int
N
,
typename
T_grad
,
typename
T_weight
>
struct
SGDFunctor
struct
SGDFunctor
{
{
__device__
__forceinline__
void
operator
()(
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
N
>
&
tl
,
volatile
int
*
noop_gmem
,
float
wd
,
float
momentum
,
float
dampening
,
float
lr
,
bool
nesterov
,
TensorListMetadata
<
N
>
&
tl
,
bool
first_run
,
bool
wd_after_momentum
,
float
scale
)
{
float
wd
,
float
momentum
,
float
dampening
,
float
lr
,
bool
nesterov
,
bool
first_run
,
bool
wd_after_momentum
,
float
scale
)
{
// Early exit if we don't need to do anything
// Early exit if we don't need to do anything
if
(
*
noop_gmem
)
if
(
*
noop_gmem
)
return
;
return
;
int
tensor_loc
=
tl
.
block_to_tensor
[
blockIdx
.
x
];
int
tensor_loc
=
tl
.
block_to_tensor
[
blockIdx
.
x
];
int
chunk_idx
=
tl
.
block_to_chunk
[
blockIdx
.
x
];
int
chunk_idx
=
tl
.
block_to_chunk
[
blockIdx
.
x
];
...
@@ -61,8 +51,7 @@ struct SGDFunctor
...
@@ -61,8 +51,7 @@ struct SGDFunctor
mom_in
+=
chunk_idx
*
chunk_size
;
mom_in
+=
chunk_idx
*
chunk_size
;
at
::
Half
*
model_weights_out
=
nullptr
;
at
::
Half
*
model_weights_out
=
nullptr
;
if
(
N
==
4
)
if
(
N
==
4
)
{
{
model_weights_out
=
(
at
::
Half
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
=
(
at
::
Half
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
+=
chunk_idx
*
chunk_size
;
model_weights_out
+=
chunk_idx
*
chunk_size
;
}
}
...
@@ -73,19 +62,15 @@ struct SGDFunctor
...
@@ -73,19 +62,15 @@ struct SGDFunctor
float
incoming_grads
[
ILP
];
float
incoming_grads
[
ILP
];
float
incoming_weights
[
ILP
];
float
incoming_weights
[
ILP
];
float
incoming_moms
[
ILP
];
float
incoming_moms
[
ILP
];
for
(
int
i_start
=
0
;
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
+=
blockDim
.
x
*
ILP
)
{
i_start
+=
blockDim
.
x
*
ILP
)
{
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
{
incoming_grads
[
ii
]
=
0
;
incoming_grads
[
ii
]
=
0
;
incoming_weights
[
ii
]
=
0
;
incoming_weights
[
ii
]
=
0
;
incoming_moms
[
ii
]
=
0
;
incoming_moms
[
ii
]
=
0
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
if
(
i
<
n
&&
i
<
chunk_size
)
{
{
incoming_grads
[
ii
]
=
static_cast
<
float
>
(
grad_in
[
i
])
*
scale
;
incoming_grads
[
ii
]
=
static_cast
<
float
>
(
grad_in
[
i
])
*
scale
;
incoming_weights
[
ii
]
=
static_cast
<
float
>
(
weight_in
[
i
]);
incoming_weights
[
ii
]
=
static_cast
<
float
>
(
weight_in
[
i
]);
incoming_moms
[
ii
]
=
static_cast
<
float
>
(
mom_in
[
i
]);
incoming_moms
[
ii
]
=
static_cast
<
float
>
(
mom_in
[
i
]);
...
@@ -98,19 +83,17 @@ struct SGDFunctor
...
@@ -98,19 +83,17 @@ struct SGDFunctor
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
if
(
i
<
n
&&
i
<
chunk_size
)
{
{
// apply weight decay before momentum if necessary
// apply weight decay before momentum if necessary
if
(
wd
!=
0.
f
&&
!
wd_after_momentum
)
if
(
wd
!=
0.
f
&&
!
wd_after_momentum
)
incoming_grads
[
ii
]
+=
wd
*
incoming_weights
[
ii
];
incoming_grads
[
ii
]
+=
wd
*
incoming_weights
[
ii
];
if
(
momentum
!=
0.
f
)
if
(
momentum
!=
0.
f
)
{
{
if
(
!
first_run
)
if
(
!
first_run
)
incoming_moms
[
ii
]
=
incoming_moms
[
ii
]
*
momentum
+
(
1.
f
-
dampening
)
*
incoming_grads
[
ii
];
incoming_moms
[
ii
]
=
incoming_moms
[
ii
]
*
momentum
+
(
1.
f
-
dampening
)
*
incoming_grads
[
ii
];
else
// initialize momentums to current incoming grads
else
// initialize momentums to current incoming grads
incoming_moms
[
ii
]
=
incoming_grads
[
ii
];
incoming_moms
[
ii
]
=
incoming_grads
[
ii
];
...
@@ -132,27 +115,18 @@ struct SGDFunctor
...
@@ -132,27 +115,18 @@ struct SGDFunctor
model_weights_out
[
i
]
=
static_cast
<
at
::
Half
>
(
weight_in
[
i
]);
model_weights_out
[
i
]
=
static_cast
<
at
::
Half
>
(
weight_in
[
i
]);
// also write out the new momentum
// also write out the new momentum
if
(
momentum
!=
0.
f
)
if
(
momentum
!=
0.
f
)
mom_in
[
i
]
=
incoming_moms
[
ii
];
mom_in
[
i
]
=
incoming_moms
[
ii
];
}
}
}
}
}
}
}
}
};
};
void
multi_tensor_sgd_cuda
(
void
multi_tensor_sgd_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
wd
,
float
wd
,
float
momentum
,
float
dampening
,
float
lr
,
float
momentum
,
bool
nesterov
,
bool
first_run
,
float
dampening
,
bool
wd_after_momentum
,
float
scale
)
{
float
lr
,
bool
nesterov
,
bool
first_run
,
bool
wd_after_momentum
,
float
scale
)
{
auto
num_tensors
=
tensor_lists
.
size
();
auto
num_tensors
=
tensor_lists
.
size
();
auto
grad_type
=
tensor_lists
[
0
][
0
].
scalar_type
();
auto
grad_type
=
tensor_lists
[
0
][
0
].
scalar_type
();
auto
weight_type
=
tensor_lists
[
1
][
0
].
scalar_type
();
auto
weight_type
=
tensor_lists
[
1
][
0
].
scalar_type
();
...
@@ -162,7 +136,8 @@ void multi_tensor_sgd_cuda(
...
@@ -162,7 +136,8 @@ void multi_tensor_sgd_cuda(
TORCH_CHECK
(
tensor_lists
[
3
][
i
].
scalar_type
()
==
at
::
ScalarType
::
Half
,
TORCH_CHECK
(
tensor_lists
[
3
][
i
].
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Additional output tensors should always be fp16."
);
"Additional output tensors should always be fp16."
);
TORCH_CHECK
(
noop_flag
.
device
()
==
tensor_lists
[
0
][
0
].
device
(),
"expected noop flag to be on the same device as tensors"
);
TORCH_CHECK
(
noop_flag
.
device
()
==
tensor_lists
[
0
][
0
].
device
(),
"expected noop flag to be on the same device as tensors"
);
// We have 3 possibilities to handle here, in terms of
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// grad_type, param_type, momentum_type, requires_fp16_copy
...
@@ -176,22 +151,10 @@ void multi_tensor_sgd_cuda(
...
@@ -176,22 +151,10 @@ void multi_tensor_sgd_cuda(
// Case 1. fp16, fp16, fp16, No
// Case 1. fp16, fp16, fp16, No
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
weight_type
==
at
::
ScalarType
::
Half
&&
weight_type
==
at
::
ScalarType
::
Half
&&
num_tensors
==
3
)
{
num_tensors
==
3
)
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
{
SGDFunctor
<
3
,
at
::
Half
,
at
::
Half
>
(),
wd
,
momentum
,
multi_tensor_apply
<
3
>
(
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
3
,
at
::
Half
,
at
::
Half
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
scale
);
}
}
// Case 2. fp16, fp32, fp32, No
// Case 2. fp16, fp32, fp32, No
...
@@ -214,68 +177,33 @@ void multi_tensor_sgd_cuda(
...
@@ -214,68 +177,33 @@ void multi_tensor_sgd_cuda(
// }
// }
// Case 2. fp32, fp32, fp32, No
// Case 2. fp32, fp32, fp32, No
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
3
)
{
num_tensors
==
3
)
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
{
SGDFunctor
<
3
,
float
,
float
>
(),
wd
,
momentum
,
multi_tensor_apply
<
3
>
(
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
3
,
float
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
scale
);
}
}
// Case 3. fp16, fp32, fp32, Yes
// Case 3. fp16, fp32, fp32, Yes
else
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
else
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
weight_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
4
)
{
num_tensors
==
4
)
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
{
SGDFunctor
<
4
,
at
::
Half
,
float
>
(),
wd
,
momentum
,
multi_tensor_apply
<
4
>
(
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
4
,
at
::
Half
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
scale
);
}
}
// Case 4. fp32, fp32, fp32, Yes
// Case 4. fp32, fp32, fp32, Yes
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
4
)
{
num_tensors
==
4
)
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
{
SGDFunctor
<
4
,
float
,
float
>
(),
wd
,
momentum
,
multi_tensor_apply
<
4
>
(
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
4
,
float
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
scale
);
}
}
else
{
else
AT_ERROR
(
{
"multi_tensor_sgd only supports some combinations of gradient & weight "
AT_ERROR
(
"multi_tensor_sgd only supports some combinations of gradient & weight types. Given: "
,
"types. Given: "
,
"gradient: "
,
grad_type
,
", weight: "
,
weight_type
,
", num_lists: "
,
num_tensors
);
"gradient: "
,
grad_type
,
", weight: "
,
weight_type
,
", num_lists: "
,
num_tensors
);
}
}
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment