Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2472 additions
and
1045 deletions
+2472
-1045
oneflow/core/common/util.h
oneflow/core/common/util.h
+1
-44
oneflow/core/common/wrap_dim_utils.h
oneflow/core/common/wrap_dim_utils.h
+20
-0
oneflow/core/control/rank_info_bootstrap_server.cpp
oneflow/core/control/rank_info_bootstrap_server.cpp
+70
-3
oneflow/core/control/rank_info_bootstrap_server.h
oneflow/core/control/rank_info_bootstrap_server.h
+6
-1
oneflow/core/control/rpc_client.cpp
oneflow/core/control/rpc_client.cpp
+22
-8
oneflow/core/cuda/atomic.cuh
oneflow/core/cuda/atomic.cuh
+149
-31
oneflow/core/cuda/elementwise.cuh
oneflow/core/cuda/elementwise.cuh
+43
-39
oneflow/core/cuda/layer_norm.cuh
oneflow/core/cuda/layer_norm.cuh
+603
-374
oneflow/core/cuda/rms_norm.cuh
oneflow/core/cuda/rms_norm.cuh
+1031
-0
oneflow/core/cuda/softmax.cuh
oneflow/core/cuda/softmax.cuh
+146
-121
oneflow/core/cuda/unique.cuh
oneflow/core/cuda/unique.cuh
+74
-47
oneflow/core/device/cuda_util.cpp
oneflow/core/device/cuda_util.cpp
+50
-6
oneflow/core/device/cuda_util.cu
oneflow/core/device/cuda_util.cu
+0
-34
oneflow/core/device/cuda_util.h
oneflow/core/device/cuda_util.h
+20
-7
oneflow/core/device/cudnn_conv_util.cpp
oneflow/core/device/cudnn_conv_util.cpp
+91
-156
oneflow/core/device/cudnn_conv_util.h
oneflow/core/device/cudnn_conv_util.h
+6
-16
oneflow/core/device/cudnn_util.cpp
oneflow/core/device/cudnn_util.cpp
+108
-17
oneflow/core/device/cudnn_util.h
oneflow/core/device/cudnn_util.h
+32
-0
oneflow/core/eager/blob_instruction_type.cpp
oneflow/core/eager/blob_instruction_type.cpp
+0
-45
oneflow/core/eager/blob_instruction_type.h
oneflow/core/eager/blob_instruction_type.h
+0
-96
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/common/util.h
View file @
a715222c
...
...
@@ -38,56 +38,13 @@ limitations under the License.
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/meta_util.hpp"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/common/hash.h"
#include "oneflow/core/common/cpp_attribute.h"
#define CHECK_ISNULL(e) CHECK((e) == nullptr)
namespace oneflow {

// Mixes `rhs` into `lhs` and returns the combined hash value.
// The formula is lhs ^ (rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2)).
inline size_t HashCombine(size_t lhs, size_t rhs) {
  return lhs ^ (rhs + 0x9e3779b9 + (lhs << 6U) + (lhs >> 2U));
}

// In-place variant: folds `hash` into the value stored at `seed`.
inline void HashCombine(size_t* seed, size_t hash) { *seed = HashCombine(*seed, hash); }

// Folds std::hash<T>()(v) of every argument into `*seed`, left to right.
// With no arguments, `*seed` is left unchanged.
template<typename... T>
inline void AddHash(size_t* seed, const T&... v) {
  // Braced-init-list elements are evaluated in order, which gives the left-to-right fold.
  // The leading 0 keeps the array non-empty when the pack is empty (single-argument Hash());
  // a zero-length array is only accepted as a compiler extension.
  const int dummy[] = {0, (HashCombine(seed, std::hash<T>()(v)), 0)...};
  static_cast<void>(dummy);
}

// Returns a hash of all arguments: std::hash of `v1` is the seed and each
// remaining argument is folded in by AddHash.
template<typename T, typename... Ts>
inline size_t Hash(const T& v1, const Ts&... vn) {
  size_t seed = std::hash<T>()(v1);
  AddHash<Ts...>(&seed, vn...);
  return seed;
}

}  // namespace oneflow
namespace std {

// std::hash for std::pair: hashes both members together via oneflow::Hash.
template<typename T0, typename T1>
struct hash<std::pair<T0, T1>> {
  std::size_t operator()(const std::pair<T0, T1>& p) const {
    const T0& head = p.first;
    const T1& tail = p.second;
    return oneflow::Hash<T0, T1>(head, tail);
  }
};

// std::hash for std::vector: the seed is the element count, then every
// element is folded in, in order, via oneflow::AddHash.
template<typename T>
struct hash<std::vector<T>> {
  std::size_t operator()(const std::vector<T>& vec) const {
    std::size_t digest = vec.size();
    for (auto it = vec.begin(); it != vec.end(); ++it) { oneflow::AddHash<T>(&digest, *it); }
    return digest;
  }
};

}  // namespace std
namespace
oneflow
{
// Declares the copy constructor and copy assignment operator of ClassName as deleted.
// The expansion carries no trailing semicolon; the use site supplies it.
#define OF_DISALLOW_COPY(ClassName) \
ClassName(const ClassName&) = delete; \
ClassName& operator=(const ClassName&) = delete
...
...
oneflow/core/common/wrap_dim_utils.h
View file @
a715222c
...
...
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <bitset>
#include "oneflow/core/common/maybe.h"
namespace
oneflow
{
...
...
@@ -37,4 +38,23 @@ static inline Maybe<int64_t> maybe_wrap_dim(int64_t dim, int64_t dim_post_expr,
if
(
dim
<
0
)
dim
+=
dim_post_expr
;
return
dim
;
}
// align with pytorch: `aten/src/ATen/WrapDimUtilsMulti.h`
// Width of the bitset returned by dim_list_to_bitset, which rejects ndims larger than this.
constexpr
size_t
dim_bitset_size
=
64
;
// Converts a list of dimension indices into a bitset with one bit set per listed dim.
//
// Each entry of `dims` is normalized with maybe_wrap_dim(entry, ndims) before use.
// Returns an error if `ndims` exceeds dim_bitset_size, if maybe_wrap_dim fails,
// or if the same (normalized) dim appears more than once.
static inline Maybe<std::bitset<dim_bitset_size>> dim_list_to_bitset(
    const std::vector<int32_t>& dims, int64_t ndims) {
  CHECK_LE_OR_RETURN(ndims, (int64_t)dim_bitset_size)
      << Error::RuntimeError() << "Only tensors with up to " << dim_bitset_size
      << " dims are supported";
  std::bitset<dim_bitset_size> seen;
  // Range-for avoids comparing a signed 32-bit index against the unsigned dims.size().
  for (const int32_t raw_dim : dims) {
    const size_t dim = JUST(maybe_wrap_dim(raw_dim, ndims));
    CHECK_OR_RETURN_ERROR(!seen[dim])
        << Error::RuntimeError() << "The dim " << dim
        << " appears multiple times in the list of dims";
    seen[dim] = true;
  }
  return seen;
}
}
// namespace oneflow
oneflow/core/control/rank_info_bootstrap_server.cpp
View file @
a715222c
...
...
@@ -13,8 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/rank_info_bootstrap_server.h"
#include <thread>
#include <mutex>
#include <chrono>
#include "grpc/grpc_posix.h"
#include "oneflow/core/common/env_var/bootstrap.h"
#include "oneflow/core/control/rank_info_bootstrap_server.h"
namespace
oneflow
{
...
...
@@ -29,12 +33,25 @@ std::string GetHostFromUri(const std::string& uri) {
return
uri
.
substr
(
first_delimiter_pos
+
1
,
second_delimiter_pos
-
first_delimiter_pos
-
1
);
}
// Returns the value produced by EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS>().
// The value is computed on the first call and reused afterwards (function-local static).
int64_t rpc_bootstrap_server_sleep_seconds() {
  static const int64_t cached_seconds = EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS>();
  return cached_seconds;
}
// Returns the value produced by EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES>().
// The value is computed on the first call and reused afterwards (function-local static).
int64_t rpc_bootstrap_server_max_retry_times() {
  static const int64_t cached_retry_limit =
      EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES>();
  return cached_retry_limit;
}
}
// namespace
RankInfoBootstrapServer
::
RankInfoBootstrapServer
(
const
BootstrapConf
&
bootstrap_conf
)
:
BootstrapServer
(),
port_
(
0
),
world_size_
(
bootstrap_conf
.
world_size
())
{
Init
();
int
p
=
(
bootstrap_conf
.
rank
()
==
0
?
bootstrap_conf
.
master_addr
().
port
()
:
0
);
const
int64_t
rank
=
bootstrap_conf
.
rank
();
int
p
=
(
rank
==
0
?
bootstrap_conf
.
master_addr
().
port
()
:
0
);
grpc
::
ServerBuilder
server_builder
;
server_builder
.
SetMaxMessageSize
(
INT_MAX
);
server_builder
.
AddListeningPort
(
"0.0.0.0:"
+
std
::
to_string
(
p
),
grpc
::
InsecureServerCredentials
(),
...
...
@@ -43,10 +60,59 @@ RankInfoBootstrapServer::RankInfoBootstrapServer(const BootstrapConf& bootstrap_
server_builder
.
RegisterService
(
grpc_service_
.
get
());
cq_
=
server_builder
.
AddCompletionQueue
();
grpc_server_
=
server_builder
.
BuildAndStart
();
if
(
bootstrap_conf
.
rank
()
==
0
)
{
CHECK_EQ
(
p
,
port
())
<<
"Port "
<<
p
<<
" is unavailable"
;
}
if
(
rank
==
0
)
{
CHECK_EQ
(
p
,
port
())
<<
"Port "
<<
p
<<
" is unavailable"
;
}
LOG
(
INFO
)
<<
"RankInfoBootstrapServer listening on "
<<
"0.0.0.0:"
+
std
::
to_string
(
port
());
loop_thread_
=
std
::
thread
(
&
RankInfoBootstrapServer
::
HandleRpcs
,
this
);
if
(
rank
==
0
)
{
rank2host_
=
std
::
make_shared
<
std
::
vector
<
std
::
string
>>
(
world_size_
,
""
);
// NOTE: use check_thread_ to check RankInfoBootstrapServer status on rank 0
// if size of ready ranks == total ranks(world_size), means status is ok.
// otherwise, it indicates that other ranks' server have not been created successfully!
check_thread_
=
std
::
thread
(
&
RankInfoBootstrapServer
::
CheckServerStatus
,
this
);
}
}
// Periodically checks how many entries of rank2host_ are non-empty.
//
// Each round sleeps rpc_bootstrap_server_sleep_seconds() seconds, counts the
// non-empty hosts while holding lock_, and returns once the count equals
// world_size_. Rounds after the first one log a warning when not all ranks are
// ready. If rpc_bootstrap_server_max_retry_times() rounds pass without success,
// the function logs at FATAL level.
void RankInfoBootstrapServer::CheckServerStatus() {
  const int64_t skip_warning_times = 1;

  // Number of ranks whose host string has been filled in.
  const auto CountReadyRanks = [](const std::shared_ptr<std::vector<std::string>>& rank2host) {
    int64_t ready = 0;
    for (const std::string& host : *rank2host) {
      if (host != "") { ready += 1; }
    }
    return ready;
  };

  for (int64_t retry_idx = 0; retry_idx < rpc_bootstrap_server_max_retry_times(); ++retry_idx) {
    std::this_thread::sleep_for(std::chrono::seconds(rpc_bootstrap_server_sleep_seconds()));

    int64_t ready_ranks = 0;
    {
      // rank2host_ is only inspected while lock_ is held.
      std::lock_guard<std::mutex> lock(lock_);
      ready_ranks = CountReadyRanks(rank2host_);
    }
    CHECK(ready_ranks <= world_size_);

    if (ready_ranks == world_size_) { return; }

    if (retry_idx >= skip_warning_times) {
      LOG(WARNING) << "BootstrapServer not ready, rpc server on some rank have not been created "
                      "successfully. Failed at "
                   << retry_idx + 1 << " times, total ranks(world_size): " << world_size_
                   << ", ready ranks: " << ready_ranks;
    }
  }

  // Every round ended without all ranks being ready.
  LOG(FATAL) << "CheckServerStatus() failed, rpc server on some rank are not ready, please check "
                "whether the processes on all ranks are "
                "created successfully.";
}
Maybe
<
const
std
::
vector
<
std
::
string
>&>
RankInfoBootstrapServer
::
rank2host
()
const
{
...
...
@@ -59,6 +125,7 @@ void RankInfoBootstrapServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* ca
CHECK_GE
(
rank
,
0
);
CHECK_LT
(
rank
,
world_size_
);
if
(
!
rank2host_
)
{
rank2host_
=
std
::
make_shared
<
std
::
vector
<
std
::
string
>>
(
world_size_
);
}
std
::
lock_guard
<
std
::
mutex
>
lock
(
lock_
);
rank2host_
->
at
(
rank
)
=
GetHostFromUri
(
call
->
server_ctx
().
peer
());
call
->
SendResponse
();
EnqueueRequest
<
CtrlMethod
::
kLoadServer
>
();
...
...
oneflow/core/control/rank_info_bootstrap_server.h
View file @
a715222c
...
...
@@ -26,7 +26,9 @@ namespace oneflow {
class
RankInfoBootstrapServer
final
:
public
BootstrapServer
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
RankInfoBootstrapServer
);
~
RankInfoBootstrapServer
()
override
=
default
;
~
RankInfoBootstrapServer
()
override
{
if
(
check_thread_
.
joinable
())
{
check_thread_
.
join
();
}
}
RankInfoBootstrapServer
(
const
BootstrapConf
&
bootstrap_conf
);
...
...
@@ -35,9 +37,12 @@ class RankInfoBootstrapServer final : public BootstrapServer {
private:
void
OnLoadServer
(
CtrlCall
<
CtrlMethod
::
kLoadServer
>*
call
)
override
;
void
CheckServerStatus
();
int
port_
;
const
int64_t
world_size_
;
std
::
mutex
lock_
;
std
::
thread
check_thread_
;
// use std::shared_ptr as std::optional
std
::
shared_ptr
<
std
::
vector
<
std
::
string
>>
rank2host_
;
};
...
...
oneflow/core/control/rpc_client.cpp
View file @
a715222c
...
...
@@ -16,13 +16,22 @@ limitations under the License.
#include "oneflow/core/control/rpc_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/common/env_var/bootstrap.h"
namespace
oneflow
{
namespace
{
// Upper bound on the number of attempts made by the LoadServer retry loop.
const
int32_t
max_retry_num
=
60
;
// Seconds slept by LoadServer after an UNAVAILABLE response before the next attempt.
const
int64_t
sleep_seconds
=
10
;
// Returns the value produced by EnvInteger<ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES>().
// The value is computed on the first call and reused afterwards (function-local static).
int64_t rpc_client_max_retry_times() {
  static const int64_t cached_retry_limit = EnvInteger<ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES>();
  return cached_retry_limit;
}
// Returns the value produced by EnvInteger<ONEFLOW_RPC_CLIENT_SLEEP_SECONDS>().
// The value is computed on the first call and reused afterwards (function-local static).
int64_t rpc_client_sleep_seconds() {
  static const int64_t cached_seconds = EnvInteger<ONEFLOW_RPC_CLIENT_SLEEP_SECONDS>();
  return cached_seconds;
}
// Checks (via CHECK_EQ) that the error_code() of status expression `x` equals
// grpc::StatusCode::OK. The argument is parenthesized so that compound
// expressions such as `cond ? a : b` expand correctly.
#define GRPC_CHECK(x) CHECK_EQ((x).error_code(), grpc::StatusCode::OK)
...
...
@@ -179,23 +188,28 @@ void RpcClient::LoadServer(const std::string& server_addr, CtrlService::Stub* st
void
RpcClient
::
LoadServer
(
const
LoadServerRequest
&
request
,
CtrlService
::
Stub
*
stub
)
{
int32_t
retry_idx
=
0
;
for
(;
retry_idx
<
max_retry_num
;
++
retry_idx
)
{
int32_t
skip_warning_times
=
3
;
for
(;
retry_idx
<
rpc_client_max_retry_times
();
++
retry_idx
)
{
grpc
::
ClientContext
client_ctx
;
LoadServerResponse
response
;
grpc
::
Status
st
=
stub
->
CallMethod
<
CtrlMethod
::
kLoadServer
>
(
&
client_ctx
,
request
,
&
response
);
if
(
st
.
error_code
()
==
grpc
::
StatusCode
::
OK
)
{
VLOG
(
3
)
<<
"LoadServer "
<<
request
.
addr
()
<<
" Successful at "
<<
retry_idx
<<
" times"
;
VLOG
(
3
)
<<
"LoadServer "
<<
request
.
addr
()
<<
" Successful at "
<<
retry_idx
+
1
<<
" times"
;
break
;
}
else
if
(
st
.
error_code
()
==
grpc
::
StatusCode
::
UNAVAILABLE
)
{
LOG
(
WARNING
)
<<
"LoadServer "
<<
request
.
addr
()
<<
" Failed at "
<<
retry_idx
<<
" times"
<<
" error_code "
<<
st
.
error_code
()
<<
" error_message "
<<
st
.
error_message
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
sleep_seconds
));
if
(
retry_idx
>=
skip_warning_times
)
{
LOG
(
WARNING
)
<<
"LoadServer "
<<
request
.
addr
()
<<
" Failed at "
<<
retry_idx
+
1
<<
" times"
<<
" error_code: "
<<
st
.
error_code
()
<<
" error_message: "
<<
st
.
error_message
();
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
rpc_client_sleep_seconds
()));
continue
;
}
else
{
LOG
(
FATAL
)
<<
st
.
error_message
();
}
}
CHECK_LT
(
retry_idx
,
max_retry_
num
);
CHECK_LT
(
retry_idx
,
rpc_client_
max_retry_
times
()
);
}
// Returns the stub stored in stubs_ at the index given by GlobalProcessCtx::Rank().
CtrlService::Stub* RpcClient::GetThisStub() {
  const auto this_rank = GlobalProcessCtx::Rank();
  return stubs_[this_rank].get();
}
...
...
oneflow/core/cuda/atomic.cuh
View file @
a715222c
...
...
@@ -16,7 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_ATOMIC_H_
#define ONEFLOW_CORE_CUDA_ATOMIC_H_
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIPCC__)
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#else
#include <cuda.h>
#include <cuda_runtime.h>
...
...
@@ -25,6 +32,9 @@ limitations under the License.
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#endif
namespace
oneflow
{
namespace
cuda
{
...
...
@@ -34,58 +44,90 @@ namespace atomic {
namespace
internal
{
template
<
typename
T
,
typename
U
>
__device__
__forceinline__
T
CastCASImpl
(
T
*
address
,
T
compare
,
T
val
)
{
static_assert
(
sizeof
(
T
)
==
sizeof
(
U
),
""
);
U
ret
=
atomicCAS
(
reinterpret_cast
<
U
*>
(
address
),
*
(
reinterpret_cast
<
U
*>
(
&
compare
)),
*
(
reinterpret_cast
<
U
*>
(
&
val
)));
return
*
(
reinterpret_cast
<
T
*>
(
&
ret
));
}
struct
CastCASImpl
{
__device__
__forceinline__
T
operator
()(
T
*
address
,
T
compare
,
T
val
,
bool
*
success
)
const
{
static_assert
(
sizeof
(
T
)
==
sizeof
(
U
),
""
);
U
assumed
=
*
(
reinterpret_cast
<
U
*>
(
&
compare
));
U
ret
=
atomicCAS
(
reinterpret_cast
<
U
*>
(
address
),
assumed
,
*
(
reinterpret_cast
<
U
*>
(
&
val
)));
*
success
=
(
ret
==
assumed
);
return
*
(
reinterpret_cast
<
T
*>
(
&
ret
));
}
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
// Compare-and-swap for 2-byte types, built on the 32-bit atomicCAS.
//
// The 16-bit value at `address` lives in either the low or the high half of a
// 32-bit word; bit 1 of the address selects which. The other half is copied
// from a plain read of the word, so the swap also fails (with *success ==
// false) when that neighbouring half changed between the read and the
// atomicCAS. Returns the 16-bit value observed at `address`, reinterpreted as T.
template<typename T>
struct CastCASImpl<T, unsigned short int> {
  __device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const {
    static_assert(sizeof(T) == sizeof(unsigned short int), "");
    // 0 when the value occupies the low half of the word, 2 for the high half.
    const size_t byte_offset = reinterpret_cast<size_t>(address) & 0x2;
    unsigned int* word_ptr =
        reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - byte_offset);
    const unsigned int current_word = *word_ptr;
    const unsigned int compare_bits = *(reinterpret_cast<unsigned short int*>(&compare));
    const unsigned int val_bits = *(reinterpret_cast<unsigned short int*>(&val));

    // Build the full-word expected/desired values, keeping the neighbouring half as read.
    unsigned int expected_word;
    unsigned int desired_word;
    if (byte_offset) {
      expected_word = (current_word & 0xffff) | (compare_bits << 16);
      desired_word = (current_word & 0xffff) | (val_bits << 16);
    } else {
      expected_word = (current_word & 0xffff0000) | compare_bits;
      desired_word = (current_word & 0xffff0000) | val_bits;
    }

    unsigned int ret = atomicCAS(word_ptr, expected_word, desired_word);
    *success = (ret == expected_word);
    // Extract the half that belongs to `address`.
    ret = byte_offset ? (ret >> 16) : (ret & 0xffff);
    return *(reinterpret_cast<T*>(&ret));
  }
};
#endif // __CUDA_ARCH__
template
<
typename
T
>
__device__
__forceinline__
typename
std
::
enable_if
<
sizeof
(
T
)
==
sizeof
(
unsigned
int
),
T
>::
type
CASImpl
(
T
*
address
,
T
compare
,
T
val
)
{
return
CastCASImpl
<
T
,
unsigned
int
>
(
address
,
compare
,
val
);
CASImpl
(
T
*
address
,
T
compare
,
T
val
,
bool
*
success
)
{
return
CastCASImpl
<
T
,
unsigned
int
>
(
)(
address
,
compare
,
val
,
success
);
}
template
<
typename
T
>
__device__
__forceinline__
typename
std
::
enable_if
<
sizeof
(
T
)
==
sizeof
(
unsigned
long
long
int
),
T
>::
type
CASImpl
(
T
*
address
,
T
compare
,
T
val
)
{
return
CastCASImpl
<
T
,
unsigned
long
long
int
>
(
address
,
compare
,
val
);
CASImpl
(
T
*
address
,
T
compare
,
T
val
,
bool
*
success
)
{
return
CastCASImpl
<
T
,
unsigned
long
long
int
>
(
)(
address
,
compare
,
val
,
success
);
}
template
<
typename
T
>
__device__
__forceinline__
typename
std
::
enable_if
<
sizeof
(
T
)
==
sizeof
(
unsigned
short
int
),
T
>::
type
CASImpl
(
T
*
address
,
T
compare
,
T
val
)
{
#if __CUDA_ARCH__ >= 700
return
CastCASImpl
<
T
,
unsigned
short
int
>
(
address
,
compare
,
val
);
#else
__trap
();
return
0
;
#endif // __CUDA_ARCH__ >= 700
CASImpl
(
T
*
address
,
T
compare
,
T
val
,
bool
*
success
)
{
return
CastCASImpl
<
T
,
unsigned
short
int
>
()(
address
,
compare
,
val
,
success
);
}
__device__
__forceinline__
int
CASImpl
(
int
*
address
,
int
compare
,
int
val
)
{
return
atomicCAS
(
address
,
compare
,
val
);
__device__
__forceinline__
int
CASImpl
(
int
*
address
,
int
compare
,
int
val
,
bool
*
success
)
{
int
ret
=
atomicCAS
(
address
,
compare
,
val
);
*
success
=
(
ret
==
compare
);
return
ret
;
}
__device__
__forceinline__
unsigned
int
CASImpl
(
unsigned
int
*
address
,
unsigned
int
compare
,
unsigned
int
val
)
{
return
atomicCAS
(
address
,
compare
,
val
);
unsigned
int
val
,
bool
*
success
)
{
unsigned
int
ret
=
atomicCAS
(
address
,
compare
,
val
);
*
success
=
(
ret
==
compare
);
return
ret
;
}
__device__
__forceinline__
unsigned
long
long
int
CASImpl
(
unsigned
long
long
int
*
address
,
unsigned
long
long
int
compare
,
unsigned
long
long
int
val
)
{
return
atomicCAS
(
address
,
compare
,
val
);
unsigned
long
long
int
val
,
bool
*
success
)
{
unsigned
long
long
int
ret
=
atomicCAS
(
address
,
compare
,
val
);
*
success
=
(
ret
==
compare
);
return
ret
;
}
#if __CUDA_ARCH__ >= 700
__device__
__forceinline__
unsigned
short
int
CASImpl
(
unsigned
short
int
*
address
,
unsigned
short
int
compare
,
unsigned
short
int
val
)
{
return
atomicCAS
(
address
,
compare
,
val
);
unsigned
short
int
val
,
bool
*
success
)
{
unsigned
short
int
ret
=
atomicCAS
(
address
,
compare
,
val
);
*
success
=
(
ret
==
compare
);
return
ret
;
}
#endif // __CUDA_ARCH__ >= 700
...
...
@@ -99,10 +141,11 @@ template<typename T, template<typename> class BinaryOp>
__device__
__forceinline__
T
AtomicCASBinaryImpl
(
T
*
address
,
T
val
)
{
T
old
=
*
address
;
T
assumed
;
bool
success
=
false
;
do
{
assumed
=
old
;
old
=
CASImpl
(
address
,
assumed
,
BinaryOp
<
T
>
()(
old
,
val
));
}
while
(
old
!=
assumed
);
old
=
CASImpl
(
address
,
assumed
,
BinaryOp
<
T
>
()(
old
,
val
)
,
&
success
);
}
while
(
!
success
);
return
old
;
}
...
...
@@ -156,17 +199,41 @@ __device__ __forceinline__ nv_bfloat16 AddImpl(nv_bfloat16* address, nv_bfloat16
return
atomicAdd
(
address
,
val
);
}
// nv_bfloat162 overload: forwards to atomicAdd(address, val) and returns its result.
__device__
__forceinline__
nv_bfloat162
AddImpl
(
nv_bfloat162
*
address
,
nv_bfloat162
val
)
{
return
atomicAdd
(
address
,
val
);
}
#endif // __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ < 530
#if
(
__CUDA_ARCH__ < 530
) && !defined(WITH_ROCM)
__device__
__forceinline__
half2
AddImpl
(
half2
*
address
,
half2
val
)
{
__trap
();
TRAP
();
return
val
;
}
#endif // __CUDA_ARCH__ < 530
#ifdef WITH_ROCM
// double overload (compiled only when WITH_ROCM is defined): forwards to
// atomicAdd(address, val) and returns its result.
__device__
__forceinline__
double
AddImpl
(
double
*
address
,
double
val
)
{
return
atomicAdd
(
address
,
val
);
}
// half overload (compiled only when WITH_ROCM is defined).
//
// Atomically adds `val` to `*address` and returns the value stored there
// before the addition. The previous implementation called atomicAdd on a local
// float copy of `*address`, so the memory at `address` was never updated.
//
// The update is a compare-and-swap loop on the 32-bit word that contains the
// 16-bit value (bit 1 of the address selects the low or high half). The sum is
// computed in float and rounded back with __float2half, as before.
__device__ __forceinline__ half AddImpl(half* address, half val) {
  static_assert(sizeof(half) == sizeof(unsigned short int), "");
  const size_t offset = reinterpret_cast<size_t>(address) & 0x2;
  unsigned int* address_as_ui =
      reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - offset);
  const float addend = __half2float(val);

  unsigned int old_word = *address_as_ui;
  unsigned int assumed_word;
  unsigned short int old_bits;
  do {
    assumed_word = old_word;
    // Current 16-bit payload of the target half-word.
    old_bits =
        static_cast<unsigned short int>(offset ? (assumed_word >> 16) : (assumed_word & 0xffff));
    half sum = __float2half(__half2float(*(reinterpret_cast<half*>(&old_bits))) + addend);
    const unsigned int sum_bits = *(reinterpret_cast<unsigned short int*>(&sum));
    // Replace only the target half; the neighbouring half is written back unchanged.
    const unsigned int new_word = offset ? (assumed_word & 0xffff) | (sum_bits << 16)
                                         : (assumed_word & 0xffff0000) | sum_bits;
    old_word = atomicCAS(address_as_ui, assumed_word, new_word);
  } while (old_word != assumed_word);  // retry if the word changed underneath us
  return *(reinterpret_cast<half*>(&old_bits));
}
// half2 overload (compiled only when WITH_ROCM is defined).
//
// Atomically adds `val` component-wise to `*address` and returns the pair
// stored there before the addition. The previous implementation called
// atomicAdd on a local float2 copy of `*address`, so the memory at `address`
// was never updated.
//
// Both components are updated together by a compare-and-swap loop on the
// 32-bit representation of the half2. Each component sum is computed in float
// and rounded back to half.
__device__ __forceinline__ half2 AddImpl(half2* address, half2 val) {
  static_assert(sizeof(half2) == sizeof(unsigned int), "");
  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(address);
  const float2 addend = __half22float2(val);

  unsigned int old_word = *address_as_ui;
  unsigned int assumed_word;
  do {
    assumed_word = old_word;
    const float2 current = __half22float2(*(reinterpret_cast<half2*>(&assumed_word)));
    half2 sum = __floats2half2_rn(current.x + addend.x, current.y + addend.y);
    old_word = atomicCAS(address_as_ui, assumed_word, *(reinterpret_cast<unsigned int*>(&sum)));
  } while (old_word != assumed_word);  // retry if the word changed underneath us
  return *(reinterpret_cast<half2*>(&old_word));
}
#endif
}
// namespace internal
template
<
typename
T
,
typename
U
>
...
...
@@ -181,7 +248,8 @@ __device__ __forceinline__ typename std::enable_if<std::is_same<T, U>::value, T>
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
__forceinline__
T
CAS
(
T
*
address
,
U
compare
,
V
val
)
{
return
internal
::
CASImpl
(
address
,
Cast
<
T
>
(
compare
),
Cast
<
T
>
(
val
));
bool
success
=
false
;
return
internal
::
CASImpl
(
address
,
Cast
<
T
>
(
compare
),
Cast
<
T
>
(
val
),
&
success
);
}
template
<
typename
T
,
typename
U
>
...
...
@@ -189,6 +257,56 @@ __device__ __forceinline__ T Add(T* address, U val) {
return
internal
::
AddImpl
(
address
,
Cast
<
T
>
(
val
));
}
// Atomically replaces *address with val * (*address) using an atomicCAS retry
// loop and returns the value that was stored before the update.
// NOTE(review): the declared return type is float although the operands are
// int32_t, so the previous value is converted on return - confirm this is intended.
__device__ __forceinline__ float Mul(int32_t* address, const int32_t val) {
  int32_t observed = *address;
  while (true) {
    const int32_t expected = observed;
    observed = atomicCAS(address, expected, val * expected);
    if (observed == expected) { break; }
  }
  return observed;
}
// Atomically replaces *address with val * (*address) using an atomicCAS retry
// loop and returns the value that was stored before the update.
// NOTE(review): the declared return type is float although the operands are
// uint32_t, so the previous value is converted on return - confirm this is intended.
__device__ __forceinline__ float Mul(uint32_t* address, const uint32_t val) {
  uint32_t observed = *address;
  while (true) {
    const uint32_t expected = observed;
    observed = atomicCAS(address, expected, val * expected);
    if (observed == expected) { break; }
  }
  return observed;
}
// Atomically replaces *address with val * (*address) using an atomicCAS retry
// loop on the unsigned long long view of the word, and returns the value that
// was stored before the update.
// NOTE(review): the declared return type is float although the operands are
// uint64_t, so the previous value is converted on return - confirm this is intended.
__device__ __forceinline__ float Mul(uint64_t* address, const uint64_t val) {
  static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), "");
  unsigned long long int* const word_ptr = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int observed = *word_ptr;
  while (true) {
    const unsigned long long int expected = observed;
    observed = atomicCAS(word_ptr, expected, static_cast<unsigned long long int>(val) * expected);
    if (observed == expected) { break; }
  }
  return observed;
}
// Atomically replaces *address with val * (*address) and returns the value
// that was stored before the update. The float is handled through its int32_t
// bit pattern so that atomicCAS can be used for the retry loop.
__device__ __forceinline__ float Mul(float* address, const float val) {
  int32_t* const bits_ptr = reinterpret_cast<int32_t*>(address);
  int32_t observed = *bits_ptr;
  while (true) {
    const int32_t expected = observed;
    const int32_t desired = __float_as_int(val * __int_as_float(expected));
    observed = atomicCAS(bits_ptr, expected, desired);
    if (observed == expected) { break; }
  }
  return __int_as_float(observed);
}
// Atomically replaces *address with val * (*address) and returns the value
// that was stored before the update. The double is handled through its
// unsigned long long bit pattern so that atomicCAS can be used for the retry loop.
// NOTE(review): the declared return type is float, so the previous double
// value is narrowed on return - confirm this is intended.
__device__ __forceinline__ float Mul(double* address, const double val) {
  unsigned long long int* const bits_ptr = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int observed = *bits_ptr;
  while (true) {
    const unsigned long long int expected = observed;
    observed =
        atomicCAS(bits_ptr, expected, __double_as_longlong(val * __longlong_as_double(expected)));
    if (observed == expected) { break; }
  }
  return __longlong_as_double(observed);
}
__device__
__forceinline__
float
Max
(
float
*
address
,
const
float
val
)
{
int
*
address_as_i
=
(
int
*
)
address
;
int
old
=
*
address_as_i
;
...
...
oneflow/core/cuda/elementwise.cuh
View file @
a715222c
...
...
@@ -16,7 +16,13 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_ELEMENTWISE_H_
#define ONEFLOW_CORE_CUDA_ELEMENTWISE_H_
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
#include "oneflow/core/ep/include/gpu_macro.h"
#include <cstdint>
#include <algorithm>
#include <type_traits>
...
...
@@ -30,25 +36,25 @@ namespace elementwise {
constexpr
int
kBlockSize
=
256
;
constexpr
int
kNumWaves
=
32
;
inline
cuda
Error_t
GetNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
inline
GPU
(
Error_t
)
GetNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
int
dev
;
{
cuda
Error_t
err
=
cuda
GetDevice
(
&
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
GetDevice
)
(
&
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
sm_count
;
{
cuda
Error_t
err
=
cuda
DeviceGetAttribute
(
&
sm_count
,
cudaDevAttr
MultiProcessorCount
,
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)
(
&
sm_count
,
GPU
MultiProcessorCount
,
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
tpm
;
{
cuda
Error_t
err
=
cuda
DeviceGetAttribute
(
&
tpm
,
cudaDevAttr
MaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)
(
&
tpm
,
GPU
MaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
((
n
+
kBlockSize
-
1
)
/
kBlockSize
,
sm_count
*
tpm
/
kBlockSize
*
kNumWaves
));
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
T
,
int
pack_size
>
...
...
@@ -113,24 +119,24 @@ class HasApply2 {
template
<
int
pack_size
,
typename
FunctorT
,
typename
R
,
typename
...
IN
>
__device__
typename
std
::
enable_if
<
HasApply2
<
FunctorT
>::
value
==
true
&&
pack_size
%
2
==
0
,
Packed
<
R
,
pack_size
>>::
type
ApplyPack
(
const
FunctorT
&
functor
,
const
IN
...
in
[
pack_size
]
)
{
ApplyPack
(
const
FunctorT
&
functor
,
const
Packed
<
IN
,
pack_size
>
...
in
)
{
Packed
<
R
,
pack_size
>
ret
;
#pragma unroll
for
(
int
j
=
0
;
j
<
pack_size
;
j
+=
2
)
{
functor
.
Apply2
(
ret
.
elem
+
j
,
(
in
+
j
)...);
}
for
(
int
j
=
0
;
j
<
pack_size
;
j
+=
2
)
{
functor
.
Apply2
(
ret
.
elem
+
j
,
(
in
.
elem
+
j
)...);
}
return
ret
;
}
template
<
int
pack_size
,
typename
FunctorT
,
typename
R
,
typename
...
IN
>
__device__
typename
std
::
enable_if
<
HasApply2
<
FunctorT
>::
value
==
false
||
pack_size
%
2
!=
0
,
Packed
<
R
,
pack_size
>>::
type
ApplyPack
(
const
FunctorT
&
functor
,
const
IN
...
in
[
pack_size
]
)
{
ApplyPack
(
const
FunctorT
&
functor
,
const
Packed
<
IN
,
pack_size
>
...
in
)
{
Packed
<
R
,
pack_size
>
ret
;
#pragma unroll
for
(
int
j
=
0
;
j
<
pack_size
;
++
j
)
{
ret
.
elem
[
j
]
=
functor
((
in
[
j
])...);
}
for
(
int
j
=
0
;
j
<
pack_size
;
++
j
)
{
ret
.
elem
[
j
]
=
functor
((
in
.
elem
[
j
])...);
}
return
ret
;
}
template
<
int
pack_size
,
bool
tail
,
typename
FactoryT
,
typename
R
,
typename
...
IN
>
template
<
int
pack_size
,
typename
FactoryT
,
typename
R
,
typename
...
IN
>
__global__
void
__launch_bounds__
(
kBlockSize
)
ApplyGeneric
(
FactoryT
factory
,
int64_t
n_pack
,
Packed
<
R
,
pack_size
>*
pack_r
,
const
Packed
<
IN
,
pack_size
>*
...
pack_in
,
int64_t
n_tail
,
R
*
tail_r
,
...
...
@@ -138,9 +144,9 @@ __global__ void __launch_bounds__(kBlockSize)
auto
functor
=
factory
();
const
int
global_tid
=
blockIdx
.
x
*
kBlockSize
+
threadIdx
.
x
;
for
(
int64_t
i
=
global_tid
;
i
<
n_pack
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
pack_r
[
i
]
=
ApplyPack
<
pack_size
,
decltype
(
functor
),
R
,
IN
...
>
(
functor
,
(
pack_in
[
i
]
.
elem
)...);
pack_r
[
i
]
=
ApplyPack
<
pack_size
,
decltype
(
functor
),
R
,
IN
...
>
(
functor
,
(
pack_in
[
i
])...);
}
if
(
tail
&&
global_tid
<
n_tail
)
{
tail_r
[
global_tid
]
=
functor
((
tail_in
[
global_tid
])...);
}
if
(
global_tid
<
n_tail
)
{
tail_r
[
global_tid
]
=
functor
((
tail_in
[
global_tid
])...);
}
}
template
<
typename
FunctorT
>
...
...
@@ -153,41 +159,39 @@ struct SimpleFactory {
};
template
<
size_t
pack_size
>
bool
IsAlig
e
ndForPack
()
{
bool
IsAlign
e
dForPack
()
{
return
true
;
}
template
<
size_t
pack_size
,
typename
T
,
typename
...
Args
>
bool
IsAlig
e
ndForPack
(
const
T
*
ptr
,
const
Args
*
...
others
)
{
bool
IsAlign
e
dForPack
(
const
T
*
ptr
,
const
Args
*
...
others
)
{
return
reinterpret_cast
<
uintptr_t
>
(
ptr
)
%
sizeof
(
Pack
<
T
,
pack_size
>
)
==
0
&&
IsAlig
e
ndForPack
<
pack_size
,
Args
...
>
(
others
...);
&&
IsAlign
e
dForPack
<
pack_size
,
Args
...
>
(
others
...);
}
template
<
size_t
pack_size
,
typename
FactoryT
,
typename
R
,
typename
...
IN
>
cuda
Error_t
LaunchKernel
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
IN
*
...
in
,
cuda
Stream_t
stream
)
{
GPU
(
Error_t
)
LaunchKernel
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
IN
*
...
in
,
GPU
(
Stream_t
)
stream
)
{
const
int64_t
n_pack
=
n
/
pack_size
;
const
int64_t
tail_offset
=
n_pack
*
pack_size
;
const
int64_t
n_tail
=
n
-
tail_offset
;
int
num_blocks
;
{
cuda
Error_t
err
=
GetNumBlocks
(
n_pack
,
&
num_blocks
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
n_pack
,
&
num_blocks
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
auto
func
=
n_tail
>
0
?
ApplyGeneric
<
pack_size
,
true
,
FactoryT
,
R
,
IN
...
>
:
ApplyGeneric
<
pack_size
,
false
,
FactoryT
,
R
,
IN
...
>
;
func
<<<
num_blocks
,
kBlockSize
,
0
,
stream
>>>
(
ApplyGeneric
<
pack_size
,
FactoryT
,
R
,
IN
...
><<<
num_blocks
,
kBlockSize
,
0
,
stream
>>>
(
factory
,
n_pack
,
reinterpret_cast
<
Packed
<
R
,
pack_size
>*>
(
r
),
(
reinterpret_cast
<
const
Packed
<
IN
,
pack_size
>*>
(
in
))...,
n_tail
,
r
+
tail_offset
,
(
in
+
tail_offset
)...);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
FactoryT
,
typename
R
,
typename
...
IN
>
struct
GenericLauncher
{
static
cuda
Error_t
Launch
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
IN
*
...
in
,
cuda
Stream_t
stream
)
{
static
GPU
(
Error_t
)
Launch
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
IN
*
...
in
,
GPU
(
Stream_t
)
stream
)
{
constexpr
int
max_pack_size
=
PackSize
<
R
,
IN
...
>
();
if
(
IsAlig
e
ndForPack
<
max_pack_size
,
R
,
IN
...
>
(
r
,
in
...))
{
if
(
IsAlign
e
dForPack
<
max_pack_size
,
R
,
IN
...
>
(
r
,
in
...))
{
return
LaunchKernel
<
max_pack_size
,
FactoryT
,
R
,
IN
...
>
(
factory
,
n
,
r
,
in
...,
stream
);
}
else
{
return
LaunchKernel
<
1
,
FactoryT
,
R
,
IN
...
>
(
factory
,
n
,
r
,
in
...,
stream
);
...
...
@@ -196,37 +200,37 @@ struct GenericLauncher {
};
template
<
typename
FactoryT
,
typename
R
,
typename
A
>
inline
cuda
Error_t
UnaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
UnaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
GPU
(
Stream_t
)
stream
)
{
return
GenericLauncher
<
FactoryT
,
R
,
A
>::
Launch
(
factory
,
n
,
r
,
a
,
stream
);
}
template
<
typename
FunctorT
,
typename
R
,
typename
A
>
inline
cuda
Error_t
Unary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
Unary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
GPU
(
Stream_t
)
stream
)
{
return
UnaryWithFactory
(
SimpleFactory
<
FunctorT
>
(
functor
),
n
,
r
,
a
,
stream
);
}
template
<
typename
FactoryT
,
typename
R
,
typename
A
,
typename
B
>
inline
cuda
Error_t
BinaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
BinaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
GPU
(
Stream_t
)
stream
)
{
return
GenericLauncher
<
FactoryT
,
R
,
A
,
B
>::
Launch
(
factory
,
n
,
r
,
a
,
b
,
stream
);
}
template
<
typename
FunctorT
,
typename
R
,
typename
A
,
typename
B
>
inline
cuda
Error_t
Binary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
Binary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
GPU
(
Stream_t
)
stream
)
{
return
BinaryWithFactory
(
SimpleFactory
<
FunctorT
>
(
functor
),
n
,
r
,
a
,
b
,
stream
);
}
template
<
typename
FactoryT
,
typename
R
,
typename
A
,
typename
B
,
typename
C
>
inline
cuda
Error_t
TernaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
const
C
*
c
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
TernaryWithFactory
(
FactoryT
factory
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
const
C
*
c
,
GPU
(
Stream_t
)
stream
)
{
return
GenericLauncher
<
FactoryT
,
R
,
A
,
B
,
C
>::
Launch
(
factory
,
n
,
r
,
a
,
b
,
c
,
stream
);
}
template
<
typename
FunctorT
,
typename
R
,
typename
A
,
typename
B
,
typename
C
>
inline
cuda
Error_t
Ternary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
const
C
*
c
,
cuda
Stream_t
stream
)
{
inline
GPU
(
Error_t
)
Ternary
(
FunctorT
functor
,
int64_t
n
,
R
*
r
,
const
A
*
a
,
const
B
*
b
,
const
C
*
c
,
GPU
(
Stream_t
)
stream
)
{
return
TernaryWithFactory
(
SimpleFactory
<
FunctorT
>
(
functor
),
n
,
r
,
a
,
b
,
c
,
stream
);
}
...
...
oneflow/core/cuda/layer_norm.cuh
View file @
a715222c
...
...
@@ -17,8 +17,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_LAYER_NORM_H_
#define ONEFLOW_CORE_CUDA_LAYER_NORM_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#include <math_constants.h>
#endif
#include <assert.h>
namespace
oneflow
{
...
...
@@ -27,7 +33,11 @@ namespace cuda {
namespace
layer_norm
{
// Default lane count used as `thread_group_width` by the warp-level helpers in this namespace:
// 64 when built WITH_ROCM, 32 otherwise.
// NOTE(review): presumably 64 mirrors the wavefront width of the targeted ROCm devices;
// confirm for devices that execute 32-wide wavefronts.
#ifdef WITH_ROCM
constexpr
int
kWarpSize
=
64
;
#else
constexpr
int
kWarpSize
=
32
;
#endif
template
<
typename
T
>
struct
SumOp
{
...
...
@@ -42,14 +52,22 @@ struct MaxOp {
// Folds `val` across a group of `thread_group_width` lanes with ReductionOp<T>.
//
// Each round fetches the value held by the lane whose index differs by `stride` (XOR shuffle)
// and combines it with the local value; `stride` starts at thread_group_width / 2 and is halved
// until it reaches zero. The combined value is returned.
//
// ROCm builds call __shfl_xor; all other builds call __shfl_xor_sync with a 0xffffffff mask.
template<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
  int stride = thread_group_width / 2;
  while (stride > 0) {
#ifdef WITH_ROCM
    const T peer = __shfl_xor(val, stride, thread_group_width);
#else
    const T peer = __shfl_xor_sync(0xffffffff, val, stride, thread_group_width);
#endif
    val = ReductionOp<T>()(val, peer);
    stride /= 2;
  }
  return val;
}
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
block_size
>
__inline__
__device__
T
BlockAllReduce
(
T
val
)
{
#ifdef WITH_ROCM
typedef
hipcub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
#else
typedef
cub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
#endif
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
result_broadcast
;
T
result
=
BlockReduce
(
temp_storage
).
Reduce
(
val
,
ReductionOp
<
T
>
());
...
...
@@ -93,26 +111,26 @@ __inline__ __device__ double Rsqrt<double>(double x) {
}
template
<
class
Func
>
inline
cuda
Error_t
GetNumBlocks
(
Func
func
,
int64_t
block_size
,
size_t
dynamic_smem_size
,
inline
GPU
(
Error_t
)
GetNumBlocks
(
Func
func
,
int64_t
block_size
,
size_t
dynamic_smem_size
,
int64_t
max_blocks
,
int64_t
waves
,
int
*
num_blocks
)
{
int
dev
;
{
cuda
Error_t
err
=
cuda
GetDevice
(
&
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
GetDevice
)
(
&
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
sm_count
;
{
cuda
Error_t
err
=
cuda
DeviceGetAttribute
(
&
sm_count
,
cudaDevAttr
MultiProcessorCount
,
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)
(
&
sm_count
,
GPU
MultiProcessorCount
,
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
max_active_blocks
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
func
,
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks
,
func
,
block_size
,
dynamic_smem_size
);
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
(
max_blocks
,
sm_count
*
max_active_blocks
*
waves
));
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
T
>
...
...
@@ -132,6 +150,34 @@ struct DefaultComputeType<nv_bfloat16> {
};
#endif // CUDA_VERSION >= 11000
// Compile-time probe for a `CanPackAs` member.
//
// HasCanPackAs<T>::value is non-zero when `&T::CanPackAs` is a well-formed expression and zero
// otherwise. Detection works by overload resolution on two private, never-defined functions whose
// return types have different sizes.
template<typename T>
class HasCanPackAs {
  using Yes = char;
  struct No {
    char pad[2];
  };

  // Viable only when &U::CanPackAs can be formed; the literal 0 converts to that pointer type.
  template<typename U>
  static Yes Probe(decltype(&U::CanPackAs));
  // Catch-all used when the overload above drops out of the candidate set.
  template<typename U>
  static No Probe(...);

 public:
  enum { value = (sizeof(Probe<T>(0)) == sizeof(Yes)) };
};
// Overload selected when HasCanPackAs<T>::value is true: the decision is delegated to the
// object's own CanPackAs(pack_size) member and its answer is returned as a bool.
template<typename T>
typename std::enable_if<HasCanPackAs<T>::value == true, bool>::type CanPackAs(T t,
                                                                             size_t pack_size) {
  const bool packable = t.CanPackAs(pack_size);
  return packable;
}
// Overload selected when HasCanPackAs<T>::value is false: a type without a CanPackAs member
// places no restriction, so every pack size is reported as acceptable.
template<typename T>
typename std::enable_if<HasCanPackAs<T>::value == false, bool>::type CanPackAs(T t,
                                                                              size_t pack_size) {
  // Both arguments are intentionally ignored.
  static_cast<void>(t);
  static_cast<void>(pack_size);
  return true;
}
template
<
typename
T
,
int
N
>
struct
GetPackType
{
using
type
=
typename
std
::
aligned_storage
<
N
*
sizeof
(
T
),
N
*
sizeof
(
T
)
>::
type
;
...
...
@@ -152,6 +198,7 @@ union Pack {
template
<
typename
SRC
,
typename
DST
>
struct
DirectLoad
{
using
LoadType
=
DST
;
DirectLoad
(
const
SRC
*
src
,
int64_t
row_size
)
:
src
(
src
),
row_size
(
row_size
)
{}
template
<
int
N
>
__device__
void
load
(
DST
*
dst
,
int64_t
row
,
int64_t
col
)
const
{
...
...
@@ -210,9 +257,15 @@ __inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T threa
*
m2
=
thread_m2
;
*
count
=
thread_count
;
for
(
int
mask
=
thread_group_width
/
2
;
mask
>
0
;
mask
/=
2
)
{
#ifdef WITH_ROCM
T
b_mean
=
__shfl_down
(
*
mean
,
mask
,
thread_group_width
);
T
b_m2
=
__shfl_down
(
*
m2
,
mask
,
thread_group_width
);
T
b_count
=
__shfl_down
(
*
count
,
mask
,
thread_group_width
);
#else
T
b_mean
=
__shfl_down_sync
(
0xffffffff
,
*
mean
,
mask
,
thread_group_width
);
T
b_m2
=
__shfl_down_sync
(
0xffffffff
,
*
m2
,
mask
,
thread_group_width
);
T
b_count
=
__shfl_down_sync
(
0xffffffff
,
*
count
,
mask
,
thread_group_width
);
#endif
WelfordCombine
(
b_mean
,
b_m2
,
b_count
,
mean
,
m2
,
count
);
}
}
...
...
@@ -221,9 +274,16 @@ template<typename T, int thread_group_width = kWarpSize>
__inline__
__device__
void
WelfordWarpAllReduce
(
T
thread_mean
,
T
thread_m2
,
T
thread_count
,
T
*
mean
,
T
*
m2
,
T
*
count
)
{
WelfordWarpReduce
<
T
,
thread_group_width
>
(
thread_mean
,
thread_m2
,
thread_count
,
mean
,
m2
,
count
);
#ifdef WITH_ROCM
*
mean
=
__shfl
(
*
mean
,
0
,
thread_group_width
);
*
m2
=
__shfl
(
*
m2
,
0
,
thread_group_width
);
*
count
=
__shfl
(
*
count
,
0
,
thread_group_width
);
#else
*
mean
=
__shfl_sync
(
0xffffffff
,
*
mean
,
0
,
thread_group_width
);
*
m2
=
__shfl_sync
(
0xffffffff
,
*
m2
,
0
,
thread_group_width
);
*
count
=
__shfl_sync
(
0xffffffff
,
*
count
,
0
,
thread_group_width
);
#endif
}
template
<
typename
T
>
...
...
@@ -258,7 +318,11 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
T
block_mean
=
0
;
T
block_m2
=
0
;
T
block_count
=
0
;
...
...
@@ -275,17 +339,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
*
result_count
=
count_result_broadcast
;
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
max_cols_per_thread
,
int
min_cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
__global__
void
LayerNormWarpImpl
(
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
static_assert
(
cols_per_thread
%
pack_size
==
0
,
""
);
using
LoadType
=
typename
LOAD
::
LoadType
;
static_assert
(
max_cols_per_thread
%
pack_size
==
0
,
""
);
static_assert
(
min_cols_per_thread
%
pack_size
==
0
,
""
);
static_assert
(
thread_group_width
<=
kWarpSize
,
""
);
static_assert
(
kWarpSize
%
thread_group_width
==
0
,
""
);
constexpr
int
num_packs
=
cols_per_thread
/
pack_size
;
assert
(
cols
<=
cols_per_thread
*
thread_group_width
);
ComputeType
buf
[
rows_per_access
][
cols_per_thread
];
constexpr
int
max_num_packs
=
max_cols_per_thread
/
pack_size
;
constexpr
int
min_num_packs
=
min_cols_per_thread
/
pack_size
;
assert
(
cols
<=
max_cols_per_thread
*
thread_group_width
);
ComputeType
buf
[
rows_per_access
][
max_cols_per_thread
];
const
int64_t
global_thread_group_id
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
const
int64_t
num_global_thread_group
=
gridDim
.
x
*
blockDim
.
y
;
const
int64_t
lane_id
=
threadIdx
.
x
;
...
...
@@ -301,13 +369,27 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
thread_count
[
row_id
]
=
0
;
ComputeType
*
row_buf
=
buf
[
row_id
];
#pragma unroll
for
(
int
pack_id
=
0
;
pack_id
<
num_packs
;
++
pack_id
)
{
for
(
int
pack_id
=
0
;
pack_id
<
min_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
const
int
pack_offset
=
pack_id
*
pack_size
;
LoadType
pack
[
pack_size
];
load
.
template
load
<
pack_size
>(
pack
,
row
+
row_id
,
col
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
row_buf
[
pack_offset
+
i
]
=
static_cast
<
ComputeType
>
(
pack
[
i
]);
WelfordCombine
(
row_buf
[
pack_offset
+
i
],
thread_mean
+
row_id
,
thread_m2
+
row_id
,
thread_count
+
row_id
);
}
}
for
(
int
pack_id
=
min_num_packs
;
pack_id
<
max_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
const
int
pack_offset
=
pack_id
*
pack_size
;
if
(
!
padding
||
col
<
cols
)
{
load
.
template
load
<
pack_size
>(
row_buf
+
pack_offset
,
row
+
row_id
,
col
);
LoadType
pack
[
pack_size
];
load
.
template
load
<
pack_size
>(
pack
,
row
+
row_id
,
col
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
row_buf
[
pack_offset
+
i
]
=
static_cast
<
ComputeType
>
(
pack
[
i
]);
WelfordCombine
(
row_buf
[
pack_offset
+
i
],
thread_mean
+
row_id
,
thread_m2
+
row_id
,
thread_count
+
row_id
);
}
...
...
@@ -336,11 +418,16 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
inv_variance
[
global_row_id
]
=
row_inv_var
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_
cols_per_thread
;
++
i
)
{
row_buf
[
i
]
=
(
row_buf
[
i
]
-
row_mean
)
*
row_inv_var
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
num_packs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
min_num_packs
;
++
i
)
{
const
int
col
=
(
i
*
thread_group_width
+
lane_id
)
*
pack_size
;
store
.
template
store
<
pack_size
>(
row_buf
+
i
*
pack_size
,
global_row_id
,
col
);
}
#pragma unroll
for
(
int
i
=
min_num_packs
;
i
<
max_num_packs
;
++
i
)
{
const
int
col
=
(
i
*
thread_group_width
+
lane_id
)
*
pack_size
;
if
(
!
padding
||
col
<
cols
)
{
store
.
template
store
<
pack_size
>(
row_buf
+
i
*
pack_size
,
global_row_id
,
col
);
...
...
@@ -350,9 +437,10 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
inline
cudaError_t
LaunchLayerNormWarpImpl
(
cudaStream_t
stream
,
LOAD
load
,
STORE
store
,
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
max_cols_per_thread
,
int
min_cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
inline
GPU
(
Error_t
)
LaunchLayerNormWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
...
...
@@ -365,171 +453,129 @@ inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE
(
rows
/
rows_per_access
+
thread_groups_per_block
-
1
)
/
thread_groups_per_block
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
LayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
,
block_size
,
0
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
max_
cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
,
block_size
,
0
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
LayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
LayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
max_
cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
<<<
grid_dim_x
,
block_dim
,
0
,
stream
>>>
(
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
inline
cudaError_t
DispatchLayerNormWarpImplPadding
(
cudaStream_t
stream
,
LOAD
load
,
STORE
store
,
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
max_cols_per_thread
,
int
min_cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
inline
GPU
(
Error_t
)
DispatchLayerNormWarpImplPadding
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
==
cols_per_thread
*
thread_group_width
)
{
return
LaunchLayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
false
>
(
if
(
cols
==
max_cols_per_thread
*
thread_group_width
)
{
// When not padding, min_cols_per_thread must equal max_cols_per_thread, so pass
// max_cols_per_thread for both the max_cols_per_thread and min_cols_per_thread parameters.
return
LaunchLayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
max_cols_per_thread
,
max_cols_per_thread
,
thread_group_width
,
rows_per_access
,
false
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
else
{
return
LaunchLayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
true
>
(
return
LaunchLayerNormWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
max_
cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
,
true
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
typename
std
::
enable_if
<
pack_size
==
1
,
cuda
Error_t
>::
type
DispatchLayerNormWarpImplCols
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
typename
std
::
enable_if
<
pack_size
==
1
,
GPU
(
Error_t
)
>::
type
DispatchLayerNormWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width)
\
else if (cols <= (thread_group_width)*pack_size) {
\
if (rows % 2 == 0) {
\
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,
0,
\
thread_group_width, 2>(
\
stream, load, store, rows, cols, epsilon, mean, inv_variance);
\
} else {
\
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,
0,
\
thread_group_width, 1>(
\
stream, load, store, rows, cols, epsilon, mean, inv_variance);
\
}
\
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(
col)
\
else if (cols <= (col)*kWarpSize) {
\
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col,
kWarpSize
, \
1>(stream, load, store, rows, cols,
epsilon, mean,
\
inv_variance);
\
#define DEFINE_ONE_ELIF(
max_col, min_col)
\
else if (cols <= (
max_
col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size,
max_
col,
min_col
, \
kWarpSize,
1>(stream, load, store, rows, cols,
\
epsilon, mean, inv_variance);
\
}
DEFINE_ONE_ELIF
(
2
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
12
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
20
)
DEFINE_ONE_ELIF
(
24
)
DEFINE_ONE_ELIF
(
28
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
2
,
1
)
DEFINE_ONE_ELIF
(
4
,
2
)
DEFINE_ONE_ELIF
(
8
,
4
)
DEFINE_ONE_ELIF
(
12
,
8
)
DEFINE_ONE_ELIF
(
16
,
12
)
DEFINE_ONE_ELIF
(
20
,
16
)
DEFINE_ONE_ELIF
(
24
,
20
)
DEFINE_ONE_ELIF
(
28
,
24
)
DEFINE_ONE_ELIF
(
32
,
28
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
typename
std
::
enable_if
<
pack_size
==
2
,
cuda
Error_t
>::
type
DispatchLayerNormWarpImplCols
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
typename
std
::
enable_if
<
pack_size
==
2
,
GPU
(
Error_t
)
>::
type
DispatchLayerNormWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
<=
0
)
{
return
cudaErrorInvalidValue
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1>(stream, load, store, rows, cols, epsilon, mean, \
inv_variance); \
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
12
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
20
)
DEFINE_ONE_ELIF
(
24
)
DEFINE_ONE_ELIF
(
28
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cudaErrorInvalidValue
;
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
);
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
typename
std
::
enable_if
<
pack_size
==
4
,
cudaError_t
>::
type
DispatchLayerNormWarpImplCols
(
cudaStream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
<=
0
)
{
return
cudaErrorInvalidValue
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF
(
1
)
DEFINE_ONE_ELIF
(
2
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(
col)
\
else if (cols <= (col)*kWarpSize)
{
\
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col,
kWarpSize
, \
1>(stream, load, store, rows, cols,
epsilon, mean,
\
inv_variance);
\
#define DEFINE_ONE_ELIF(
max_col, min_col)
\
else if
(
(cols <= (
max_
col)*kWarpSize)
&& (cols > (min_col)*kWarpSize)) {
\
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size,
max_
col,
min_col
, \
kWarpSize,
1>(stream, load, store, rows, cols,
\
epsilon, mean, inv_variance);
\
}
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
12
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
20
)
DEFINE_ONE_ELIF
(
24
)
DEFINE_ONE_ELIF
(
28
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
4
,
2
)
DEFINE_ONE_ELIF
(
8
,
4
)
DEFINE_ONE_ELIF
(
12
,
8
)
DEFINE_ONE_ELIF
(
16
,
12
)
DEFINE_ONE_ELIF
(
20
,
16
)
DEFINE_ONE_ELIF
(
24
,
20
)
DEFINE_ONE_ELIF
(
28
,
24
)
DEFINE_ONE_ELIF
(
32
,
28
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
struct
DispatchLayerNormWarpImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
%
4
==
0
)
{
return
DispatchLayerNormWarpImplCols
<
LOAD
,
STORE
,
ComputeType
,
4
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
else
if
(
cols
%
2
==
0
)
{
if
(
cols
%
2
==
0
&&
CanPackAs
<
LOAD
>
(
load
,
2
)
&&
CanPackAs
<
STORE
>
(
store
,
2
))
{
return
DispatchLayerNormWarpImplCols
<
LOAD
,
STORE
,
ComputeType
,
2
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
else
{
...
...
@@ -540,7 +586,7 @@ struct DispatchLayerNormWarpImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
DispatchLayerNormWarpImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
DispatchLayerNormWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
...
...
@@ -552,8 +598,9 @@ template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int
__global__
void
LayerNormBlockSMemImpl
(
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
using
LoadType
=
typename
LOAD
::
LoadType
;
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
Compute
Type
*>
(
shared_buf
);
auto
*
buf
=
reinterpret_cast
<
Load
Type
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
assert
(
cols
%
pack_size
==
0
);
const
int
num_packs
=
static_cast
<
int
>
(
cols
)
/
pack_size
;
...
...
@@ -562,12 +609,12 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
ComputeType
thread_m2
=
0
;
ComputeType
thread_count
=
0
;
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
Compute
Type
pack
[
pack_size
];
Load
Type
pack
[
pack_size
];
load
.
template
load
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
buf
[
i
*
num_packs
+
pack_id
]
=
pack
[
i
];
WelfordCombine
(
pack
[
i
],
&
thread_mean
,
&
thread_m2
,
&
thread_count
);
WelfordCombine
(
static_cast
<
ComputeType
>
(
pack
[
i
]
)
,
&
thread_mean
,
&
thread_m2
,
&
thread_count
);
}
}
ComputeType
row_mean
=
0
;
...
...
@@ -585,7 +632,7 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
ComputeType
pack
[
pack_size
];
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
=
(
buf
[
i
*
num_packs
+
pack_id
]
-
row_mean
)
*
row_inv_var
;
pack
[
i
]
=
(
static_cast
<
ComputeType
>
(
buf
[
i
*
num_packs
+
pack_id
]
)
-
row_mean
)
*
row_inv_var
;
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
}
...
...
@@ -593,88 +640,152 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
block_size
>
inline
cuda
Error_t
LaunchLayerNormBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
LaunchLayerNormBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
int
smem
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
,
block_size
,
smem
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
<<<
grid_dim_x
,
block_size
,
smem
,
stream
>>>
(
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
return
cudaPeekAtLastError
();
return
GPU
(
PeekAtLastError
)();
}
// Raises the dynamic shared memory limit of kernel `func`.
//
// Queries the kernel's attributes, then sets its MaxDynamicSharedMemorySize attribute to
// `max_smem_size` minus the kernel's own `sharedSizeBytes` minus a 1 KiB reserve.
// Returns the status of the attribute query if it fails, otherwise the status of the
// attribute update.
// NOTE(review): the difference is not checked for being positive; presumably the runtime
// rejects an invalid size - confirm.
template<typename Func>
GPU(Error_t) MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) {
  constexpr int reserved_smem = 1024;  // 1K
#ifdef WITH_ROCM
  // Under WITH_ROCM both runtime calls receive the kernel as an untyped address.
  const void* kernel = (const void*)func;
#else
  Func kernel = func;
#endif
  GPU(FuncAttributes) attr{};
  const GPU(Error_t) query_status = GPU(FuncGetAttributes)(&attr, kernel);
  if (query_status != GPU(Success)) { return query_status; }
  return GPU(FuncSetAttribute)(kernel, GPU(FuncAttributeMaxDynamicSharedMemorySize),
                               max_smem_size - attr.sharedSizeBytes - reserved_smem);
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
inline
cuda
Error_t
TryDispatchLayerNormBlockSMemImplBlockSize
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
inline
GPU
(
Error_t
)
TryDispatchLayerNormBlockSMemImplBlockSize
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
,
bool
*
success
)
{
constexpr
int
block_size_conf_1
=
128
;
constexpr
int
block_size_conf_2
=
256
;
constexpr
int
block_size_conf_3
=
512
;
constexpr
int
block_size_conf_4
=
1024
;
const
size_t
smem
=
cols
*
sizeof
(
ComputeType
);
int
max_active_blocks_conf_1
;
int
dev
=
0
;
{
GPU
(
Error_t
)
err
=
GPU
(
GetDevice
)(
&
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
}
int
sm_count
=
0
;
{
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)(
&
sm_count
,
GPUMultiProcessorCount
,
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
}
static
const
bool
max_smem_configed
=
[
=
]()
{
int
max_smem_size
=
0
;
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)(
&
max_smem_size
,
GPUMaxSharedMemoryPerBlockOptin
,
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
return
true
;
}();
const
size_t
smem
=
cols
*
sizeof
(
typename
LOAD
::
LoadType
);
int
max_active_blocks_conf_1
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_1
,
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
,
block_size_conf_1
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_1
<=
0
)
{
*
success
=
false
;
return
cuda
Success
;
return
GPU
(
Success
)
;
}
int
max_active_blocks_conf_4
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_4
,
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
,
block_size_conf_4
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_4
>
0
&&
rows
<=
sm_count
))
{
*
success
=
true
;
return
LaunchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
(
stream
,
load
,
store
,
smem
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
int
max_active_blocks_conf_3
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_3
,
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
,
block_size_conf_3
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_3
>
0
&&
rows
<=
sm_count
)
)
{
*
success
=
true
;
return
LaunchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
(
stream
,
load
,
store
,
smem
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
int
max_active_blocks_conf_2
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_2
,
LayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
,
block_size_conf_2
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_2
>
0
&&
rows
<=
sm_count
)
)
{
*
success
=
true
;
return
LaunchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
(
stream
,
load
,
store
,
smem
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
*
success
=
true
;
return
LaunchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
(
stream
,
load
,
store
,
smem
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
...
...
@@ -682,13 +793,13 @@ inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
struct
TryDispatchLayerNormBlockSMemImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
,
bool
*
success
)
{
if
(
cols
%
4
==
0
)
{
if
(
cols
%
4
==
0
&&
CanPackAs
<
LOAD
>
(
load
,
4
)
&&
CanPackAs
<
STORE
>
(
store
,
4
)
)
{
return
TryDispatchLayerNormBlockSMemImplBlockSize
<
LOAD
,
STORE
,
ComputeType
,
4
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
,
success
);
}
else
if
(
cols
%
2
==
0
)
{
}
else
if
(
cols
%
2
==
0
&&
CanPackAs
<
LOAD
>
(
load
,
2
)
&&
CanPackAs
<
STORE
>
(
store
,
2
)
)
{
return
TryDispatchLayerNormBlockSMemImplBlockSize
<
LOAD
,
STORE
,
ComputeType
,
2
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
,
success
);
}
else
{
...
...
@@ -699,7 +810,7 @@ struct TryDispatchLayerNormBlockSMemImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
TryDispatchLayerNormBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
TryDispatchLayerNormBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
,
bool
*
success
)
{
...
...
@@ -708,9 +819,10 @@ inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD l
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
block_size
>
__global__
void
LayerNormBlockUncachedImpl
(
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
__global__
void
__launch_bounds__
(
1024
)
LayerNormBlockUncachedImpl
(
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
using
LoadType
=
typename
LOAD
::
LoadType
;
const
int
tid
=
threadIdx
.
x
;
assert
(
cols
%
pack_size
==
0
);
const
int
num_packs
=
static_cast
<
int
>
(
cols
)
/
pack_size
;
...
...
@@ -719,11 +831,11 @@ __global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t
ComputeType
thread_m2
=
0
;
ComputeType
thread_count
=
0
;
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
Compute
Type
pack
[
pack_size
];
Load
Type
pack
[
pack_size
];
load
.
template
load
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
WelfordCombine
(
pack
[
i
],
&
thread_mean
,
&
thread_m2
,
&
thread_count
);
WelfordCombine
(
static_cast
<
ComputeType
>
(
pack
[
i
]
)
,
&
thread_mean
,
&
thread_m2
,
&
thread_count
);
}
}
ComputeType
row_mean
=
0
;
...
...
@@ -738,18 +850,21 @@ __global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t
inv_variance
[
row
]
=
row_inv_var
;
}
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
ComputeType
pack
[
pack_size
];
LoadType
pack
[
pack_size
];
ComputeType
dst_pack
[
pack_size
];
const
int
pack_offset
=
pack_id
*
pack_size
;
load
.
template
load
<
pack_size
>(
pack
,
row
,
pack_offset
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
=
(
pack
[
i
]
-
row_mean
)
*
row_inv_var
;
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_offset
);
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
dst_pack
[
i
]
=
(
static_cast
<
ComputeType
>
(
pack
[
i
])
-
row_mean
)
*
row_inv_var
;
}
store
.
template
store
<
pack_size
>(
dst_pack
,
row
,
pack_offset
);
}
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
inline
cuda
Error_t
LaunchLayerNormBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
LaunchLayerNormBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
...
...
@@ -757,25 +872,25 @@ inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD lo
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
,
block_size
,
0
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
LayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
<<<
grid_dim_x
,
block_size
,
0
,
stream
>>>
(
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
struct
DispatchLayerNormBlockUncachedImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
%
4
==
0
)
{
if
(
cols
%
4
==
0
&&
CanPackAs
<
LOAD
>
(
load
,
4
)
&&
CanPackAs
<
STORE
>
(
store
,
4
)
)
{
return
LaunchLayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
4
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
else
if
(
cols
%
2
==
0
)
{
}
else
if
(
cols
%
2
==
0
&&
CanPackAs
<
LOAD
>
(
load
,
2
)
&&
CanPackAs
<
STORE
>
(
store
,
2
)
)
{
return
LaunchLayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
2
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
else
{
...
...
@@ -786,7 +901,7 @@ struct DispatchLayerNormBlockUncachedImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
DispatchLayerNormBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
DispatchLayerNormBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
...
...
@@ -795,8 +910,8 @@ inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLayerNorm
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLayerNorm
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
if
(
cols
<=
1024
)
{
...
...
@@ -805,22 +920,22 @@ DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t row
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
TryDispatchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
>
(
GPU
(
Error_t
)
err
=
TryDispatchLayerNormBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchLayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
>
(
stream
,
load
,
store
,
rows
,
cols
,
epsilon
,
mean
,
inv_variance
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLayerNorm
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLayerNorm
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
return
DispatchLayerNormBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
>
(
...
...
@@ -836,18 +951,22 @@ dx = cols * dy - sum_stats1 - normalized * sum_stats2
dx *= inv_var / cols
*/
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
int
pack_size
,
int
max_cols_per_thread
,
int
min_
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
__global__
void
LayerNormGradWarpImpl
(
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
static_assert
(
cols_per_thread
%
pack_size
==
0
,
""
);
constexpr
int
pack_per_thread
=
cols_per_thread
/
pack_size
;
assert
(
cols
<=
cols_per_thread
*
thread_group_width
);
using
LoadTypeX
=
typename
LOAD_X
::
LoadType
;
using
LoadTypeDy
=
typename
LOAD_SCALED_DY
::
LoadType
;
static_assert
(
max_cols_per_thread
%
pack_size
==
0
,
""
);
static_assert
(
min_cols_per_thread
%
pack_size
==
0
,
""
);
constexpr
int
max_num_packs
=
max_cols_per_thread
/
pack_size
;
constexpr
int
min_num_packs
=
min_cols_per_thread
/
pack_size
;
assert
(
cols
<=
max_cols_per_thread
*
thread_group_width
);
static_assert
(
thread_group_width
<=
kWarpSize
,
""
);
static_assert
(
kWarpSize
%
thread_group_width
==
0
,
""
);
ComputeType
normalized_buf
[
rows_per_access
][
cols_per_thread
];
ComputeType
dy_buf
[
rows_per_access
][
cols_per_thread
];
ComputeType
normalized_buf
[
rows_per_access
][
max_
cols_per_thread
];
ComputeType
dy_buf
[
rows_per_access
][
max_
cols_per_thread
];
const
ComputeType
one_over_cols
=
static_cast
<
ComputeType
>
(
1.0
)
/
static_cast
<
ComputeType
>
(
cols
);
const
int64_t
global_thread_group_id
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
const
int64_t
num_global_thread_group
=
gridDim
.
x
*
blockDim
.
y
;
...
...
@@ -867,18 +986,40 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
ComputeType
*
row_normalized_buf
=
normalized_buf
[
row_id
];
ComputeType
*
row_dy_buf
=
dy_buf
[
row_id
];
#pragma unroll
for
(
int
pack_id
=
0
;
pack_id
<
pack_per_thread
;
++
pack_id
)
{
for
(
int
pack_id
=
0
;
pack_id
<
min_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
const
int
pack_offset
=
pack_id
*
pack_size
;
if
(
!
padding
||
col
<
cols
)
{
load_x
.
template
load
<
pack_size
>(
row_normalized_buf
+
pack_offset
,
global_row_id
,
col
);
load_scaled_dy
.
template
load
<
pack_size
>(
row_dy_buf
+
pack_offset
,
global_row_id
,
col
);
LoadTypeX
pack_x
[
pack_size
];
LoadTypeDy
pack_dy
[
pack_size
];
load_x
.
template
load
<
pack_size
>(
pack_x
,
global_row_id
,
col
);
load_scaled_dy
.
template
load
<
pack_size
>(
pack_dy
,
global_row_id
,
col
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
col_id
=
pack_offset
+
i
;
// row_normalized_buf store x
row_normalized_buf
[
col_id
]
=
(
static_cast
<
ComputeType
>
(
pack_x
[
i
])
-
mean_val
)
*
inv_variance_buf
[
row_id
];
row_dy_buf
[
col_id
]
=
static_cast
<
ComputeType
>
(
pack_dy
[
i
]);
sum_stats1
[
row_id
]
+=
row_dy_buf
[
col_id
];
sum_stats2
[
row_id
]
+=
row_dy_buf
[
col_id
]
*
row_normalized_buf
[
col_id
];
}
}
#pragma unroll
for
(
int
pack_id
=
min_num_packs
;
pack_id
<
max_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
const
int
pack_offset
=
pack_id
*
pack_size
;
if
(
col
<
cols
)
{
LoadTypeX
pack_x
[
pack_size
];
LoadTypeDy
pack_dy
[
pack_size
];
load_x
.
template
load
<
pack_size
>(
pack_x
,
global_row_id
,
col
);
load_scaled_dy
.
template
load
<
pack_size
>(
pack_dy
,
global_row_id
,
col
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
col_id
=
pack_offset
+
i
;
// row_normalized_buf store x
row_normalized_buf
[
col_id
]
=
(
row_normalized_buf
[
col_id
]
-
mean_val
)
*
inv_variance_buf
[
row_id
];
(
static_cast
<
ComputeType
>
(
pack_x
[
i
])
-
mean_val
)
*
inv_variance_buf
[
row_id
];
row_dy_buf
[
col_id
]
=
static_cast
<
ComputeType
>
(
pack_dy
[
i
]);
sum_stats1
[
row_id
]
+=
row_dy_buf
[
col_id
];
sum_stats2
[
row_id
]
+=
row_dy_buf
[
col_id
]
*
row_normalized_buf
[
col_id
];
}
...
...
@@ -901,16 +1042,29 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
ComputeType
*
row_dy_buf
=
dy_buf
[
row_id
];
const
ComputeType
inv_variance_over_cols
=
inv_variance_buf
[
row_id
]
*
one_over_cols
;
#pragma unroll
for
(
int
pack_id
=
0
;
pack_id
<
pack_per_thread
;
++
pack_id
)
{
for
(
int
pack_id
=
0
;
pack_id
<
min_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
if
(
!
padding
||
col
<
cols
)
{
const
int
pack_offset
=
pack_id
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
col_id
=
pack_offset
+
i
;
row_dy_buf
[
col_id
]
=
(
cols
*
row_dy_buf
[
col_id
]
-
warp_sum_stats1
[
row_id
]
-
row_normalized_buf
[
col_id
]
*
warp_sum_stats2
[
row_id
])
*
inv_variance_over_cols
;
}
store
.
template
store
<
pack_size
>(
row_dy_buf
+
pack_offset
,
global_row_id
,
col
);
}
#pragma unroll
for
(
int
pack_id
=
min_num_packs
;
pack_id
<
max_num_packs
;
++
pack_id
)
{
const
int
col
=
(
pack_id
*
thread_group_width
+
lane_id
)
*
pack_size
;
if
(
col
<
cols
)
{
const
int
pack_offset
=
pack_id
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
col_id
=
pack_
id
*
pack_size
+
i
;
const
int
col_id
=
pack_
offset
+
i
;
row_dy_buf
[
col_id
]
=
(
cols
*
row_dy_buf
[
col_id
]
-
warp_sum_stats1
[
row_id
]
-
row_normalized_buf
[
col_id
]
*
warp_sum_stats2
[
row_id
])
*
inv_variance_over_cols
;
}
store
.
template
store
<
pack_size
>(
row_dy_buf
+
pack_
id
*
pack_size
,
global_row_id
,
col
);
store
.
template
store
<
pack_size
>(
row_dy_buf
+
pack_
offset
,
global_row_id
,
col
);
}
}
}
...
...
@@ -918,9 +1072,9 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
>
inline
cuda
Error_t
LaunchLayerNormGradWarpImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
int
pack_size
,
int
max_cols_per_thread
,
int
min_
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
inline
GPU
(
Error_t
)
LaunchLayerNormGradWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
...
...
@@ -934,143 +1088,100 @@ inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_
(
rows
/
rows_per_access
+
thread_groups_per_block
-
1
)
/
thread_groups_per_block
;
int
grid_dim_x
;
{
cudaError_t
err
=
GetNumBlocks
(
LayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
,
block_size
,
0
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
max_cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
>
,
block_size
,
0
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
}
LayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
>
LayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
max_
cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
>
<<<
grid_dim_x
,
block_dim
,
0
,
stream
>>>
(
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
inline
cudaError_t
DispatchLayerNormGradWarpImplPadding
(
cudaStream_t
stream
,
LOAD_X
load_x
,
int
pack_size
,
int
max_cols_per_thread
,
int
min_cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
>
inline
GPU
(
Error_t
)
DispatchLayerNormGradWarpImplPadding
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
==
cols_per_thread
*
thread_group_width
)
{
if
(
cols
==
max_cols_per_thread
*
thread_group_width
)
{
// when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass
// max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param.
return
LaunchLayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
false
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
max_cols_per_thread
,
max_cols_per_thread
,
thread_group_width
,
rows_per_access
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
else
{
return
LaunchLayerNormGradWarpImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
true
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
typename
std
::
enable_if
<
pack_size
==
1
,
cudaError_t
>::
type
DispatchLayerNormGradWarpImplCols
(
cudaStream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cudaErrorInvalidValue
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF
(
2
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
12
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
20
)
DEFINE_ONE_ELIF
(
24
)
DEFINE_ONE_ELIF
(
28
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cudaErrorInvalidValue
;
max_cols_per_thread
,
min_cols_per_thread
,
thread_group_width
,
rows_per_access
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
typename
std
::
enable_if
<
pack_size
==
2
,
cuda
Error_t
>::
type
DispatchLayerNormGradWarpImplCols
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
typename
std
::
enable_if
<
pack_size
==
1
,
GPU
(
Error_t
)
>::
type
DispatchLayerNormGradWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width)
\
else if (cols <= (thread_group_width)*pack_size) {
\
if (rows % 2 == 0) {
\
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
\
pack_size, pack_size,
0,
thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
\
} else {
\
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
\
pack_size, pack_size,
0,
thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
\
}
\
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(
col)
\
else if (cols <= (col)*kWarpSize) {
\
#define DEFINE_ONE_ELIF(
max_col, min_col)
\
else if (cols <= (
max_
col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>(
\
pack_size,
max_col, min_
col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
12
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
20
)
DEFINE_ONE_ELIF
(
24
)
DEFINE_ONE_ELIF
(
28
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
2
,
1
)
DEFINE_ONE_ELIF
(
4
,
2
)
DEFINE_ONE_ELIF
(
8
,
4
)
DEFINE_ONE_ELIF
(
12
,
8
)
DEFINE_ONE_ELIF
(
16
,
12
)
DEFINE_ONE_ELIF
(
20
,
16
)
DEFINE_ONE_ELIF
(
24
,
20
)
DEFINE_ONE_ELIF
(
28
,
24
)
DEFINE_ONE_ELIF
(
32
,
28
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
struct
DispatchLayerNormGradWarpImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
)
{
return
DispatchLayerNormGradWarpImplCols
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
2
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
else
{
return
DispatchLayerNormGradWarpImplCols
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
1
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
return
DispatchLayerNormGradWarpImplCols
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
1
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
};
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
DispatchLayerNormGradWarpImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
inline
GPU
(
Error_t
)
DispatchLayerNormGradWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
...
...
@@ -1085,9 +1196,11 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
using
LoadTypeX
=
typename
LOAD_X
::
LoadType
;
using
LoadTypeDy
=
typename
LOAD_SCALED_DY
::
LoadType
;
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
grad_shared_buf
[];
auto
*
normalized_buf
=
reinterpret_cast
<
Compute
Type
*>
(
grad_shared_buf
);
auto
*
dy_buf
=
normalized_buf
+
cols
;
auto
*
normalized_buf
=
reinterpret_cast
<
Load
Type
X
*>
(
grad_shared_buf
);
auto
*
dy_buf
=
reinterpret_cast
<
LoadTypeDy
*>
(
normalized_buf
+
cols
)
;
const
int
tid
=
threadIdx
.
x
;
assert
(
cols
%
pack_size
==
0
);
const
int
num_packs
=
static_cast
<
int
>
(
cols
)
/
pack_size
;
...
...
@@ -1099,18 +1212,19 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
const
ComputeType
inv_variance_val
=
inv_variance
[
row
];
const
ComputeType
inv_variance_over_cols
=
inv_variance_val
*
one_over_cols
;
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
Compute
Type
x_pack
[
pack_size
];
Compute
Type
dy_pack
[
pack_size
];
Load
Type
X
x_pack
[
pack_size
];
Load
Type
Dy
dy_pack
[
pack_size
];
load_x
.
template
load
<
pack_size
>(
x_pack
,
row
,
pack_id
*
pack_size
);
load_scaled_dy
.
template
load
<
pack_size
>(
dy_pack
,
row
,
pack_id
*
pack_size
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
buf_offset
=
i
*
num_packs
+
pack_id
;
ComputeType
normalized
=
(
x_pack
[
i
]
-
mean_val
)
*
inv_variance_val
;
normalized_buf
[
buf_offset
]
=
normalized
;
ComputeType
normalized
=
(
static_cast
<
ComputeType
>
(
x_pack
[
i
])
-
mean_val
)
*
inv_variance_val
;
normalized_buf
[
buf_offset
]
=
static_cast
<
LoadTypeX
>
(
normalized
);
dy_buf
[
buf_offset
]
=
dy_pack
[
i
];
sum_stats1
+=
dy_pack
[
i
];
sum_stats2
+=
dy_pack
[
i
]
*
normalized
;
sum_stats1
+=
static_cast
<
ComputeType
>
(
dy_pack
[
i
]
)
;
sum_stats2
+=
static_cast
<
ComputeType
>
(
dy_pack
[
i
]
)
*
normalized
;
}
}
const
ComputeType
row_sum_stats1
=
BlockAllReduce
<
SumOp
,
ComputeType
,
block_size
>
(
sum_stats1
);
...
...
@@ -1120,8 +1234,8 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
const
int
buf_offset
=
i
*
num_packs
+
pack_id
;
pack
[
i
]
=
(
cols
*
dy_buf
[
buf_offset
]
-
row_sum_stats1
-
normalized_buf
[
buf_offset
]
*
row_sum_stats2
)
pack
[
i
]
=
(
cols
*
static_cast
<
ComputeType
>
(
dy_buf
[
buf_offset
]
)
-
row_sum_stats1
-
static_cast
<
ComputeType
>
(
normalized_buf
[
buf_offset
]
)
*
row_sum_stats2
)
*
inv_variance_over_cols
;
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
...
...
@@ -1131,7 +1245,7 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
block_size
>
inline
cuda
Error_t
LaunchLayerNormGradBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
inline
GPU
(
Error_t
)
LaunchLayerNormGradBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
int
smem
,
...
...
@@ -1139,86 +1253,139 @@ inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
,
block_size
,
smem
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
<<<
grid_dim_x
,
block_size
,
smem
,
stream
>>>
(
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
inline
cuda
Error_t
TryDispatchLayerNormGradBlockSMemImplBlockSize
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
inline
GPU
(
Error_t
)
TryDispatchLayerNormGradBlockSMemImplBlockSize
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
constexpr
int
block_size_conf_1
=
128
;
constexpr
int
block_size_conf_2
=
256
;
constexpr
int
block_size_conf_3
=
512
;
constexpr
int
block_size_conf_4
=
1024
;
const
size_t
smem
=
cols
*
sizeof
(
ComputeType
)
*
2
;
int
dev
=
0
;
{
GPU
(
Error_t
)
err
=
GPU
(
GetDevice
)(
&
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
}
int
sm_count
=
0
;
{
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)(
&
sm_count
,
GPUMultiProcessorCount
,
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
}
static
const
bool
max_smem_configed
=
[
=
]()
{
int
max_smem_size
=
0
;
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)(
&
max_smem_size
,
GPUMaxSharedMemoryPerBlockOptin
,
dev
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
err
=
MaximizeDynamicSharedMemorySize
(
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
,
max_smem_size
);
if
(
err
!=
GPU
(
Success
))
{
return
false
;
}
return
true
;
}();
using
LoadTypeX
=
typename
LOAD_X
::
LoadType
;
using
LoadTypeDy
=
typename
LOAD_SCALED_DY
::
LoadType
;
const
size_t
smem
=
cols
*
(
sizeof
(
LoadTypeX
)
+
sizeof
(
LoadTypeDy
));
int
max_active_blocks_conf_1
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_1
,
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
,
block_size_conf_1
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_1
<=
0
)
{
*
success
=
false
;
return
cuda
Success
;
return
GPU
(
Success
)
;
}
int
max_active_blocks_conf_4
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_4
,
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
,
block_size_conf_4
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_4
>
0
&&
rows
<=
sm_count
))
{
*
success
=
true
;
return
LaunchLayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
smem
,
rows
,
cols
);
}
int
max_active_blocks_conf_3
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_3
,
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
,
block_size_conf_3
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_3
>
0
&&
rows
<=
sm_count
))
{
*
success
=
true
;
return
LaunchLayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
smem
,
rows
,
cols
);
}
int
max_active_blocks_conf_2
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_2
,
LayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
,
block_size_conf_2
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
)
{
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
||
(
max_active_blocks_conf_2
>
0
&&
rows
<=
sm_count
))
{
*
success
=
true
;
return
LaunchLayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
smem
,
rows
,
cols
);
}
*
success
=
true
;
return
LaunchLayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
...
...
@@ -1227,10 +1394,11 @@ inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
struct
TryDispatchLayerNormGradBlockSMemImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
if
(
cols
%
2
==
0
)
{
if
(
cols
%
2
==
0
&&
CanPackAs
<
LOAD_X
>
(
load_x
,
2
)
&&
CanPackAs
<
LOAD_SCALED_DY
>
(
load_scaled_dy
,
2
)
&&
CanPackAs
<
STORE
>
(
store
,
2
))
{
return
TryDispatchLayerNormGradBlockSMemImplBlockSize
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
2
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
,
success
);
...
...
@@ -1243,7 +1411,7 @@ struct TryDispatchLayerNormGradBlockSMemImplPackSize {
};
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
TryDispatchLayerNormGradBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
inline
GPU
(
Error_t
)
TryDispatchLayerNormGradBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
...
...
@@ -1260,6 +1428,8 @@ __global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY loa
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
using
LoadTypeX
=
typename
LOAD_X
::
LoadType
;
using
LoadTypeDy
=
typename
LOAD_SCALED_DY
::
LoadType
;
const
int
tid
=
threadIdx
.
x
;
assert
(
cols
%
pack_size
==
0
);
const
int
num_packs
=
static_cast
<
int
>
(
cols
)
/
pack_size
;
...
...
@@ -1271,75 +1441,134 @@ __global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY loa
ComputeType
sum_stats1
=
0
;
ComputeType
sum_stats2
=
0
;
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
ComputeType
x_pack
[
pack_size
]
;
Compute
Type
dy
_pack
[
pack_size
];
l
oad
_x
.
template
load
<
pack_size
>(
x_pack
,
row
,
pack_id
*
pack_size
)
;
load_
scaled_dy
.
template
load
<
pack_size
>(
dy
_pack
,
row
,
pack_
id
*
pack_size
);
const
int
pack_offset
=
pack_id
*
pack_size
;
Load
Type
X
x
_pack
[
pack_size
];
L
oad
TypeDy
dy_pack
[
pack_size
]
;
load_
x
.
template
load
<
pack_size
>(
x
_pack
,
row
,
pack_
offset
);
load_scaled_dy
.
template
load
<
pack_size
>(
dy_pack
,
row
,
pack_offset
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
sum_stats1
+=
dy_pack
[
i
];
sum_stats2
+=
dy_pack
[
i
]
*
(
x_pack
[
i
]
-
mean_val
)
*
inv_variance_val
;
sum_stats1
+=
static_cast
<
ComputeType
>
(
dy_pack
[
i
]);
sum_stats2
+=
static_cast
<
ComputeType
>
(
dy_pack
[
i
])
*
(
static_cast
<
ComputeType
>
(
x_pack
[
i
])
-
mean_val
)
*
inv_variance_val
;
}
}
const
ComputeType
row_sum_stats1
=
BlockAllReduce
<
SumOp
,
ComputeType
,
block_size
>
(
sum_stats1
);
const
ComputeType
row_sum_stats2
=
BlockAllReduce
<
SumOp
,
ComputeType
,
block_size
>
(
sum_stats2
);
for
(
int
pack_id
=
tid
;
pack_id
<
num_packs
;
pack_id
+=
block_size
)
{
ComputeType
x_pack
[
pack_size
];
ComputeType
dy_pack
[
pack_size
];
load_x
.
template
load
<
pack_size
>(
x_pack
,
row
,
pack_id
*
pack_size
);
load_scaled_dy
.
template
load
<
pack_size
>(
dy_pack
,
row
,
pack_id
*
pack_size
);
const
int
pack_offset
=
pack_id
*
pack_size
;
LoadTypeX
x_pack
[
pack_size
];
LoadTypeDy
dy_pack
[
pack_size
];
ComputeType
dx_pack
[
pack_size
];
load_x
.
template
load
<
pack_size
>(
x_pack
,
row
,
pack_offset
);
load_scaled_dy
.
template
load
<
pack_size
>(
dy_pack
,
row
,
pack_offset
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
dy_pack
[
i
]
=
(
cols
*
dy_pack
[
i
]
-
row_sum_stats1
-
(
x_pack
[
i
]
-
mean_val
)
*
inv_variance_val
*
row_sum_stats2
)
*
inv_variance_over_cols
;
dx_pack
[
i
]
=
(
cols
*
static_cast
<
ComputeType
>
(
dy_pack
[
i
])
-
row_sum_stats1
-
(
static_cast
<
ComputeType
>
(
x_pack
[
i
])
-
mean_val
)
*
inv_variance_val
*
row_sum_stats2
)
*
inv_variance_over_cols
;
}
store
.
template
store
<
pack_size
>(
d
y
_pack
,
row
,
pack_
id
*
pack_size
);
store
.
template
store
<
pack_size
>(
d
x
_pack
,
row
,
pack_
offset
);
}
}
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
>
inline
cuda
Error_t
LaunchLayerNormGradBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
int
pack_size
,
int
block_size
>
inline
GPU
(
Error_t
)
LaunchLayerNormGradBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
block_size
=
1024
;
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
GetNumBlocks
(
LayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
,
block_size
,
0
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
LayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
>
<<<
grid_dim_x
,
block_size
,
0
,
stream
>>>
(
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
return
cudaPeekAtLastError
();
return
GPU
(
PeekAtLastError
)();
}
// Launches LayerNormGradBlockUncachedImpl with the largest block size that the occupancy query
// reports as resident on a multiprocessor.
//
// Block sizes 1024, 512 and 256 are probed in that order; the first one for which the query
// reports more than zero active blocks is launched. If none qualifies, the 128-thread
// configuration is launched without probing.
//
// Returns the status of LaunchLayerNormGradBlockUncachedImpl for the chosen block size.
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
         int pack_size>
inline GPU(Error_t) TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize(
    GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
    const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
    const int64_t cols) {
  int max_active_blocks = 0;
  constexpr int block_size_conf_1 = 1024;
  {
    // The query status is not propagated: only the reported block count decides whether this
    // configuration is used, otherwise the next smaller block size is tried.
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
        &max_active_blocks,
        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
                                       block_size_conf_1>,
        block_size_conf_1, 0);
    (void)err;
    if (max_active_blocks > 0) {
      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
                                                  pack_size, block_size_conf_1>(
          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
    }
  }
  constexpr int block_size_conf_2 = 512;
  {
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
        &max_active_blocks,
        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
                                       block_size_conf_2>,
        block_size_conf_2, 0);
    (void)err;
    if (max_active_blocks > 0) {
      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
                                                  pack_size, block_size_conf_2>(
          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
    }
  }
  constexpr int block_size_conf_3 = 256;
  {
    // The probed block size must match the kernel instantiation (256 threads); probing with the
    // 512-thread size would answer the question for a launch shape that is never used.
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
        &max_active_blocks,
        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
                                       block_size_conf_3>,
        block_size_conf_3, 0);
    (void)err;
    if (max_active_blocks > 0) {
      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
                                                  pack_size, block_size_conf_3>(
          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
    }
  }
  // Last resort: launch the smallest configuration unconditionally.
  constexpr int block_size_conf_4 = 128;
  return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
                                              pack_size, block_size_conf_4>(
      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
struct
DispatchLayerNormGradBlockUncachedImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
&&
cols
>
kWarpSize
)
{
return
LaunchLayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
2
>
(
if
(
cols
%
2
==
0
&&
CanPackAs
<
LOAD_X
>
(
load_x
,
2
)
&&
CanPackAs
<
LOAD_SCALED_DY
>
(
load_scaled_dy
,
2
)
&&
CanPackAs
<
STORE
>
(
store
,
2
)
&&
cols
>
kWarpSize
)
{
return
TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
2
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
else
{
return
LaunchLayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
1
>
(
return
TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
,
1
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
}
};
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
inline
cuda
Error_t
DispatchLayerNormGradBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
inline
GPU
(
Error_t
)
DispatchLayerNormGradBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
...
...
@@ -1350,8 +1579,8 @@ inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, L
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLayerNormGrad
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLayerNormGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
1024
)
{
...
...
@@ -1360,23 +1589,23 @@ DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_sc
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
TryDispatchLayerNormGradBlockSMemImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchLayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
>
(
stream
,
load_x
,
load_scaled_dy
,
store
,
mean
,
inv_variance
,
rows
,
cols
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD_X
,
typename
LOAD_SCALED_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLayerNormGrad
(
cuda
Stream_t
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLayerNormGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_X
load_x
,
LOAD_SCALED_DY
load_scaled_dy
,
STORE
store
,
const
ComputeType
*
mean
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchLayerNormGradBlockUncachedImpl
<
LOAD_X
,
LOAD_SCALED_DY
,
STORE
,
ComputeType
>
(
...
...
oneflow/core/cuda/rms_norm.cuh
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_CUDA_RMS_NORM_H_
#define ONEFLOW_CORE_CUDA_RMS_NORM_H_
#include "oneflow/core/cuda/layer_norm.cuh"
namespace
oneflow
{
namespace
cuda
{
namespace
rms_norm
{
// Number of lanes the kernels in this namespace treat as one warp: 64 when built with
// WITH_ROCM, 32 otherwise.
#if defined(WITH_ROCM)
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif
// Adds `val` across lanes using shuffle-down with halving offsets (32..1 under WITH_ROCM,
// 16..1 otherwise) and returns the calling lane's accumulated value.
template<typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#ifdef WITH_ROCM
  int offset = 32;
  while (offset > 0) {
    val += __shfl_down(val, offset);
    offset /= 2;
  }
#else
  int offset = 16;
  while (offset > 0) {
    val += __shfl_down_sync(0xffffffff, val, offset);
    offset /= 2;
  }
#endif
  return val;
}
// RMS-norm forward kernel in which one thread group (threadIdx.x spans the group, threadIdx.y
// selects the group within the block) handles `rows_per_access` consecutive rows per iteration.
//
// For each row the kernel loads the values into a per-thread buffer, accumulates the sum of
// squares, reduces it across the thread group, computes
//   inv_rms = Rsqrt(sum_sq / ncol + eps),
// writes inv_rms[row] from lane 0 of the group and stores every buffered value multiplied by
// inv_rms.
//
// Template parameters:
//   pack_size            elements moved per load/store call.
//   max_cols_per_thread  size of the per-thread row buffer.
//   min_cols_per_thread  leading columns that are loaded/stored without a bounds check.
//   thread_group_width   number of threads cooperating on one row.
//   rows_per_access      rows handled by a thread group per loop iteration.
//   padding              when true, packs past min_cols_per_thread are bounds-checked against
//                        ncol and zero-filled when out of range.
//
// Parameters:
//   nrow     total number of rows; must be a multiple of rows_per_access.
//   ncol     number of columns; must not exceed max_cols_per_thread * thread_group_width.
//   eps      added to the mean of squares before the reciprocal square root.
//   inv_rms  output, one value per row.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
         int rows_per_access, bool padding>
__global__ void RmsNormWarpImpl(LOAD load, STORE store, const int nrow, const int ncol,
                                const double eps, ComputeType* inv_rms) {
  static_assert(max_cols_per_thread % pack_size == 0, "");
  static_assert(min_cols_per_thread % pack_size == 0, "");
  static_assert(thread_group_width <= kWarpSize, "");
  static_assert(kWarpSize % thread_group_width == 0, "");
  constexpr int max_packs = max_cols_per_thread / pack_size;
  constexpr int min_packs = min_cols_per_thread / pack_size;
  assert(ncol <= max_cols_per_thread * thread_group_width);
  assert(nrow % rows_per_access == 0);
  ComputeType buf[rows_per_access][max_cols_per_thread];
  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int num_global_thread_groups = gridDim.x * blockDim.y;
  // row_i counts groups of rows_per_access rows, so it must stop at nrow / rows_per_access.
  // Bounding it by nrow would address rows up to nrow * rows_per_access - 1.
  const int num_row_groups = nrow / rows_per_access;
  for (int row_i = global_thread_group_id; row_i < num_row_groups;
       row_i += num_global_thread_groups) {
    ComputeType thread_square_sum[rows_per_access];
#pragma unroll
    for (int row_j = 0; row_j < rows_per_access; ++row_j) {
      thread_square_sum[row_j] = 0;
      ComputeType* row_buf = buf[row_j];
      const int row = row_i * rows_per_access + row_j;
      // Packs that are in range for every thread of the group: no bounds check.
#pragma unroll
      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        load.template load<pack_size>(row_buf + pack_offset, row, col);
#pragma unroll
        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
          thread_square_sum[row_j] += row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];
        }
      }
      // Remaining packs: checked against ncol when padding is enabled, zero-filled otherwise.
#pragma unroll
      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        if (!padding || col < ncol) {
          load.template load<pack_size>(row_buf + pack_offset, row, col);
#pragma unroll
          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
            thread_square_sum[row_j] +=
                row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];
          }
        } else {
#pragma unroll
          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
            row_buf[pack_i * pack_size + pack_j] = 0;
          }
        }
      }
    }
    ComputeType warp_square_sum[rows_per_access];
#pragma unroll
    for (int row_j = 0; row_j < rows_per_access; ++row_j) {
      const int row = row_i * rows_per_access + row_j;
      ComputeType* row_buf = buf[row_j];
      warp_square_sum[row_j] =
          layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(
              thread_square_sum[row_j]);
      ComputeType row_square_mean =
          layer_norm::Div(warp_square_sum[row_j], static_cast<ComputeType>(ncol));
      ComputeType row_inv_rms =
          layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
      if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }
#pragma unroll
      for (int col = 0; col < max_cols_per_thread; ++col) { row_buf[col] *= row_inv_rms; }
#pragma unroll
      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);
      }
#pragma unroll
      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        if (!padding || col < ncol) {
          store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);
        }
      }
    }
  }
}
// Host-side launcher for RmsNormWarpImpl.
//
// Uses 128 threads per block arranged as (thread_group_width, 128 / thread_group_width), asks
// layer_norm::GetNumBlocks for the grid size given the number of blocks needed to cover
// nrow / rows_per_access row groups, launches the kernel on `stream` and returns the result of
// GPU(PeekAtLastError)(). An error from GetNumBlocks is returned without launching.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
         int rows_per_access, bool padding>
GPU(Error_t) LaunchRmsNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                   const int64_t nrow, const int64_t ncol, const double eps,
                                   ComputeType* inv_rms) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % thread_group_width == 0, "");
  constexpr int thread_groups_per_block = block_size / thread_group_width;

  // Ceil-divide the number of row groups by the thread groups one block provides.
  const int64_t row_groups = nrow / rows_per_access;
  const int64_t num_blocks = (row_groups + thread_groups_per_block - 1) / thread_groups_per_block;

  int grid_dim_x;
  const GPU(Error_t) grid_status = layer_norm::GetNumBlocks(
      RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
                      min_cols_per_thread, thread_group_width, rows_per_access, padding>,
      block_size, 0, num_blocks, waves, &grid_dim_x);
  if (grid_status != GPU(Success)) { return grid_status; }

  dim3 block_dim(thread_group_width, thread_groups_per_block);
  RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread, min_cols_per_thread,
                  thread_group_width, rows_per_access, padding>
      <<<grid_dim_x, block_dim, 0, stream>>>(load, store, static_cast<int>(nrow),
                                             static_cast<int>(ncol), eps, inv_rms);
  return GPU(PeekAtLastError)();
}
// Chooses between the padded and unpadded instantiation of the warp kernel.
//
// When ncol exactly fills max_cols_per_thread * thread_group_width no bounds checks are needed,
// so the kernel is launched with padding = false; otherwise padding = true is used.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
         int rows_per_access>
GPU(Error_t) DispatchLaunchRmsNormWarpImplPadding(GPU(Stream_t) stream, LOAD load, STORE store,
                                                  const int64_t nrow, const int64_t ncol,
                                                  const double eps, ComputeType* inv_rms) {
  const bool needs_padding = (ncol != max_cols_per_thread * thread_group_width);
  if (needs_padding) {
    return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
                                 min_cols_per_thread, thread_group_width, rows_per_access, true>(
        stream, load, store, nrow, ncol, eps, inv_rms);
  }
  // Without padding, min_cols_per_thread must equal max_cols_per_thread, so
  // max_cols_per_thread is passed for both template parameters.
  return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
                               max_cols_per_thread, thread_group_width, rows_per_access, false>(
      stream, load, store, nrow, ncol, eps, inv_rms);
}
// Column-count dispatch for the warp kernel, overload enabled for pack_size == 1.
//
// Small rows (ncol <= thread_group_width * pack_size for widths 4, 8, 16, 32) use one pack per
// thread and process two rows per access when nrow is even. Wider rows use a full warp per row
// with max_col columns per thread, picking the first max_col for which
// ncol <= max_col * kWarpSize. Returns GPU(ErrorInvalidValue) when ncol <= 0 or when no
// configuration covers ncol.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLaunchRmsNormWarpImplCols(
    GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
    const double eps, ComputeType* inv_rms) {
  if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width)                                                     \
  else if (ncol <= (thread_group_width)*pack_size) {                                            \
    return (nrow % 2 == 0)                                                                      \
               ? DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size,      \
                                                      pack_size, 0, thread_group_width, 2>(     \
                   stream, load, store, nrow, ncol, eps, inv_rms)                               \
               : DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size,      \
                                                      pack_size, 0, thread_group_width, 1>(     \
                   stream, load, store, nrow, ncol, eps, inv_rms);                              \
  }
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col)                                                      \
  else if (ncol <= (max_col)*kWarpSize) {                                                      \
    return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col,  \
                                                min_col, kWarpSize, 1>(                        \
        stream, load, store, nrow, ncol, eps, inv_rms);                                        \
  }
  DEFINE_ONE_ELIF(2, 1)
  DEFINE_ONE_ELIF(4, 2)
  DEFINE_ONE_ELIF(8, 4)
  DEFINE_ONE_ELIF(12, 8)
  DEFINE_ONE_ELIF(16, 12)
  DEFINE_ONE_ELIF(20, 16)
  DEFINE_ONE_ELIF(24, 20)
  DEFINE_ONE_ELIF(28, 24)
  DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
  else {
    return GPU(ErrorInvalidValue);
  }
}
// Column-count dispatch for the warp kernel, overload enabled for pack_size == 2.
//
// Small rows (ncol <= thread_group_width * pack_size) use one pack per thread and process two
// rows per access when nrow is even. Wider rows use a full warp per row with max_col columns
// per thread, selected by (min_col * kWarpSize, max_col * kWarpSize]. Returns
// GPU(ErrorInvalidValue) when ncol <= 0 or when no configuration covers ncol.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLaunchRmsNormWarpImplCols(
    GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
    const double eps, ComputeType* inv_rms) {
  if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width)                                                      \
  else if (ncol <= (thread_group_width)*pack_size) {                                             \
    if (nrow % 2 == 0) {                                                                         \
      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                                  0, thread_group_width, 2>(                     \
          stream, load, store, nrow, ncol, eps, inv_rms);                                        \
    } else {                                                                                     \
      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                                  0, thread_group_width, 1>(                     \
          stream, load, store, nrow, ncol, eps, inv_rms);                                        \
    }                                                                                            \
  }
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
#ifdef WITH_ROCM
  // With kWarpSize == 64 the full-warp chain below only starts above 2 * 64 columns, while the
  // entries above stop at 32 * 2 = 64. A 64-wide group closes the (64, 128] gap that would
  // otherwise be rejected with ErrorInvalidValue.
  // NOTE(review): relies on layer_norm::WarpAllReduce accepting thread_group_width == 64 on
  // ROCm - confirm against layer_norm.cuh.
  DEFINE_ONE_ELIF(64)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col)                                                      \
  else if ((ncol <= (max_col)*kWarpSize) && (ncol > (min_col)*kWarpSize)) {                    \
    return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col,  \
                                                min_col, kWarpSize, 1>(stream, load, store,    \
                                                                       nrow, ncol, eps,        \
                                                                       inv_rms);               \
  }
  DEFINE_ONE_ELIF(4, 2)
  DEFINE_ONE_ELIF(8, 4)
  DEFINE_ONE_ELIF(12, 8)
  DEFINE_ONE_ELIF(16, 12)
  DEFINE_ONE_ELIF(20, 16)
  DEFINE_ONE_ELIF(24, 20)
  DEFINE_ONE_ELIF(28, 24)
  DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
  else {
    return GPU(ErrorInvalidValue);
  }
}
// Picks the pack size for the warp kernel: 2 when ncol is even and both the loader and the
// storer report (via layer_norm::CanPackAs) that they can be accessed in packs of 2,
// otherwise 1.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormWarpImplPackSize(GPU(Stream_t) stream, LOAD load, STORE store,
                                                   const int64_t nrow, const int64_t ncol,
                                                   const double eps, ComputeType* inv_rms) {
  const bool use_pack2 = ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
                         && layer_norm::CanPackAs<STORE>(store, 2);
  if (!use_pack2) {
    return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 1>(stream, load, store,
                                                                          nrow, ncol, eps, inv_rms);
  }
  return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,
                                                                        ncol, eps, inv_rms);
}
// Entry point of the warp-per-row implementation; forwards every argument unchanged to the
// pack-size dispatcher.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                           const int64_t nrow, const int64_t ncol,
                                           const double eps, ComputeType* inv_rms) {
  return DispatchLaunchRmsNormWarpImplPackSize<LOAD, STORE, ComputeType>(stream, load, store, nrow,
                                                                         ncol, eps, inv_rms);
}
// RMS-norm forward kernel in which one thread block handles one row at a time and keeps the
// row in dynamic shared memory between the reduction pass and the scaling pass.
//
// Pass 1: each thread loads its packs, copies them into the shared buffer and accumulates the
//         sum of squares; the block-wide sum comes from layer_norm::BlockAllReduce.
// Pass 2: each thread re-reads the packs it wrote, multiplies them by
//         inv_rms = Rsqrt(sum_sq / ncol + eps) and stores the result.
// Thread 0 writes inv_rms[row]. ncol must be a multiple of pack_size.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void RmsNormBlockSMemImpl(LOAD load, STORE store, const int nrow, const int ncol,
                                     const double eps, ComputeType* inv_rms) {
  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
  auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
  assert(ncol % pack_size == 0);
  const int num_packs = ncol / pack_size;
  const int tid = threadIdx.x;
  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
    ComputeType thread_square_sum = 0;
    for (int pack_i = tid; pack_i < num_packs; pack_i += block_size) {
      const int col = pack_i * pack_size;
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, col);
      ComputeType* cached = buf + col;
#pragma unroll
      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
        cached[pack_j] = pack[pack_j];
        thread_square_sum += pack[pack_j] * pack[pack_j];
      }
    }
    const ComputeType row_square_sum =
        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);
    const ComputeType row_square_mean =
        layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));
    const ComputeType row_inv_rms =
        layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
    if (tid == 0) { inv_rms[row] = row_inv_rms; }
    for (int pack_i = tid; pack_i < num_packs; pack_i += block_size) {
      const int col = pack_i * pack_size;
      const ComputeType* cached = buf + col;
      ComputeType pack[pack_size];
#pragma unroll
      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
        pack[pack_j] = cached[pack_j] * row_inv_rms;
      }
      store.template store<pack_size>(pack, row, col);
    }
  }
}
// Host-side launcher for RmsNormBlockSMemImpl.
//
// Obtains the grid size from layer_norm::GetNumBlocks (block_size threads, smem_size bytes of
// dynamic shared memory, nrow as the block-count bound), launches the kernel on `stream` and
// returns GPU(PeekAtLastError)(). An error from GetNumBlocks is returned without launching.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
GPU(Error_t) LaunchRmsNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                        size_t smem_size, const int64_t nrow, const int64_t ncol,
                                        const double eps, ComputeType* inv_rms) {
  constexpr int waves = 32;
  int grid_dim_x;
  const GPU(Error_t) grid_status = layer_norm::GetNumBlocks(
      RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size,
      smem_size, nrow, waves, &grid_dim_x);
  if (grid_status != GPU(Success)) { return grid_status; }
  RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
      <<<grid_dim_x, block_size, smem_size, stream>>>(load, store, nrow, ncol, eps, inv_rms);
  return GPU(PeekAtLastError)();
}
// Tries to launch the shared-memory kernel and picks its block size from occupancy queries.
//
// The 128-thread configuration is queried first and its active-block count becomes the
// reference. If that count is not positive, *success is set to false and GPU(Success) is
// returned without launching. Otherwise 1024, 512 and 256 threads are queried in that order
// and the first one whose count equals the reference is launched. If none matches, the
// 128-thread configuration is launched. *success is set to true whenever a launch is issued.
// A failing occupancy query is returned immediately.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImplBlockSize(
    GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
    const double eps, ComputeType* inv_rms, bool* success) {
  constexpr int block_size_conf_1 = 128;
  constexpr int block_size_conf_2 = 256;
  constexpr int block_size_conf_3 = 512;
  constexpr int block_size_conf_4 = 1024;
  // One ComputeType slot per column of the cached row.
  const size_t smem_size = ncol * sizeof(ComputeType);
  int max_active_blocks = 0;
  int num_blocks = 0;
#define SELECT_BLOCK_SIZE_CONF(block_size_conf)                                                  \
  {                                                                                              \
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(                           \
        &num_blocks, RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>, \
        block_size_conf, smem_size);                                                             \
    if (err != GPU(Success)) { return err; }                                                     \
    if (max_active_blocks != 0) {                                                                \
      if (num_blocks == max_active_blocks) {                                                     \
        *success = true;                                                                         \
        return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>( \
            stream, load, store, smem_size, nrow, ncol, eps, inv_rms);                           \
      }                                                                                          \
    } else if (num_blocks <= max_active_blocks) {                                                \
      *success = false;                                                                          \
      return GPU(Success);                                                                       \
    } else {                                                                                     \
      max_active_blocks = num_blocks;                                                            \
    }                                                                                            \
  }
  SELECT_BLOCK_SIZE_CONF(block_size_conf_1)
  SELECT_BLOCK_SIZE_CONF(block_size_conf_4)
  SELECT_BLOCK_SIZE_CONF(block_size_conf_3)
  SELECT_BLOCK_SIZE_CONF(block_size_conf_2)
#undef SELECT_BLOCK_SIZE_CONF
  *success = true;
  return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(
      stream, load, store, smem_size, nrow, ncol, eps, inv_rms);
}
// Picks the widest pack size (4, then 2, then 1) for the shared-memory kernel. A pack size is
// eligible when it divides ncol and layer_norm::CanPackAs accepts it for both the loader and
// the storer. The result and *success come from the block-size dispatcher.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImplPackSize(
    GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
    const double eps, ComputeType* inv_rms, bool* success) {
  if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)
      && layer_norm::CanPackAs<STORE>(store, 4)) {
    return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(
        stream, load, store, nrow, ncol, eps, inv_rms, success);
  }
  if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
      && layer_norm::CanPackAs<STORE>(store, 2)) {
    return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(
        stream, load, store, nrow, ncol, eps, inv_rms, success);
  }
  return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(
      stream, load, store, nrow, ncol, eps, inv_rms, success);
}
// Entry point of the shared-memory implementation; forwards every argument unchanged to the
// pack-size dispatcher, which reports through *success whether a kernel was launched.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                                   const int64_t nrow, const int64_t ncol,
                                                   const double eps, ComputeType* inv_rms,
                                                   bool* success) {
  return TryDispatchLaunchRmsNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>(
      stream, load, store, nrow, ncol, eps, inv_rms, success);
}
// RMS-norm forward kernel in which one thread block handles one row at a time without caching
// the row: the input is loaded once to accumulate the sum of squares and loaded again to be
// scaled by inv_rms = Rsqrt(sum_sq / ncol + eps) and stored.
// Thread 0 writes inv_rms[row]. ncol must be a multiple of pack_size.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void RmsNormBlockUncachedImpl(LOAD load, STORE store, const int nrow, const int ncol,
                                         const double eps, ComputeType* inv_rms) {
  assert(ncol % pack_size == 0);
  const int num_packs = ncol / pack_size;
  const int tid = threadIdx.x;
  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
    // First pass: per-thread partial sum of squares.
    ComputeType thread_square_sum = 0;
    for (int pack_i = tid; pack_i < num_packs; pack_i += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_i * pack_size);
#pragma unroll
      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
        thread_square_sum += pack[pack_j] * pack[pack_j];
      }
    }
    const ComputeType row_square_sum =
        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);
    const ComputeType row_square_mean =
        layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));
    const ComputeType row_inv_rms =
        layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
    if (tid == 0) { inv_rms[row] = row_inv_rms; }
    // Second pass: reload, scale and store.
    for (int pack_i = tid; pack_i < num_packs; pack_i += block_size) {
      ComputeType pack[pack_size];
      const int col = pack_i * pack_size;
      load.template load<pack_size>(pack, row, col);
#pragma unroll
      for (int pack_j = 0; pack_j < pack_size; ++pack_j) { pack[pack_j] *= row_inv_rms; }
      store.template store<pack_size>(pack, row, col);
    }
  }
}
// Host-side launcher for RmsNormBlockUncachedImpl with a fixed block size of 1024 threads and
// no dynamic shared memory. The grid size comes from layer_norm::GetNumBlocks; an error from
// it is returned without launching. Otherwise returns GPU(PeekAtLastError)() after the launch.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) LaunchRmsNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                            const int64_t nrow, const int64_t ncol,
                                            const double eps, ComputeType* inv_rms) {
  constexpr int block_size = 1024;
  constexpr int waves = 32;
  int grid_dim_x;
  const GPU(Error_t) grid_status = layer_norm::GetNumBlocks(
      RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size, 0,
      nrow, waves, &grid_dim_x);
  if (grid_status != GPU(Success)) { return grid_status; }
  RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
      <<<grid_dim_x, block_size, 0, stream>>>(load, store, nrow, ncol, eps, inv_rms);
  return GPU(PeekAtLastError)();
}
// Picks the widest pack size (4, then 2, then 1) for the uncached kernel. A pack size is
// eligible when it divides ncol and layer_norm::CanPackAs accepts it for both the loader and
// the storer.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormBlockUncachedImplPackSize(
    GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
    const double eps, ComputeType* inv_rms) {
  if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)
      && layer_norm::CanPackAs<STORE>(store, 4)) {
    return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(stream, load, store, nrow,
                                                                       ncol, eps, inv_rms);
  }
  if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
      && layer_norm::CanPackAs<STORE>(store, 2)) {
    return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,
                                                                       ncol, eps, inv_rms);
  }
  return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(stream, load, store, nrow,
                                                                     ncol, eps, inv_rms);
}
// Entry point of the uncached block implementation: simply defers to the pack-size
// dispatcher with the same arguments and returns its status.
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
                                                    const int64_t nrow, const int64_t ncol,
                                                    const double eps, ComputeType* inv_rms) {
  return DispatchLaunchRmsNormBlockUncachedImplPackSize<LOAD, STORE, ComputeType>(
      stream, load, store, nrow, ncol, eps, inv_rms);
}
// RMS-norm forward launcher for every ComputeType except double.
//
// Strategy:
//   * ncol <= 1024: warp implementation.
//   * otherwise   : try the shared-memory block implementation; when it reports that it
//                   did not launch, fall back to the uncached block implementation.
// Any error from a dispatcher is returned to the caller unchanged.
template<typename LOAD, typename STORE, typename ComputeType>
typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNorm(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow,
              const int64_t ncol, const double eps, ComputeType* inv_rms) {
  if (ncol <= 1024) {
    return DispatchLaunchRmsNormWarpImpl(stream, load, store, nrow, ncol, eps, inv_rms);
  }

  bool smem_impl_launched = false;
  const GPU(Error_t) smem_status = TryDispatchLaunchRmsNormBlockSMemImpl(
      stream, load, store, nrow, ncol, eps, inv_rms, &smem_impl_launched);
  if (smem_status != GPU(Success)) { return smem_status; }
  if (smem_impl_launched) { return GPU(Success); }

  return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms);
}
// RMS-norm forward launcher selected when ComputeType is double: always uses the
// uncached block implementation, regardless of ncol.
template<typename LOAD, typename STORE, typename ComputeType>
typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNorm(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow,
              const int64_t ncol, const double eps, ComputeType* inv_rms) {
  return DispatchLaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
      stream, load, store, nrow, ncol, eps, inv_rms);
}
// RMS-norm input-gradient kernel in which one thread group (threadIdx.x spans the group,
// threadIdx.y selects the group inside the block) handles `rows_per_access` consecutive
// rows per iteration.
//
// For each row the kernel computes, with x_hat = x * inv_rms[row]:
//     s  = sum over columns of dy * x_hat          (reduced across the thread group)
//     dx = (dy - x_hat / ncol * s) * inv_rms[row]
// and writes dx through `store`.
//
// Column handling: each thread owns up to `max_cols_per_thread` columns. The first
// `min_cols_per_thread` of them are loaded/stored WITHOUT a bounds check, so callers must
// guarantee ncol >= min_cols_per_thread * thread_group_width; the remainder is guarded by
// `global_col < ncol`.
//
// Row handling: `nrow` is the total number of rows and must be a multiple of
// `rows_per_access`. Row groups are indexed by row_i in [0, nrow / rows_per_access).
// (The previous bound `row_i < nrow` let groups with row_i >= nrow / rows_per_access
// touch rows >= nrow whenever rows_per_access > 1.)
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
         int rows_per_access>
__global__ void RmsNormGradWarpImpl(const int nrow, const int ncol, LOAD_X load_x,
                                    LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
  static_assert(max_cols_per_thread % pack_size == 0, "");
  static_assert(min_cols_per_thread % pack_size == 0, "");
  static_assert(thread_group_width <= kWarpSize, "");
  static_assert(kWarpSize % thread_group_width == 0, "");
  assert(ncol <= max_cols_per_thread * thread_group_width);
  assert(nrow % rows_per_access == 0);
  constexpr int max_packs = max_cols_per_thread / pack_size;
  constexpr int min_packs = min_cols_per_thread / pack_size;
  // Per-thread copies of x_hat and dy for the columns this thread owns.
  ComputeType normalized_buf[rows_per_access][max_cols_per_thread];
  ComputeType dy_buf[rows_per_access][max_cols_per_thread];
  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int num_global_thread_group = gridDim.x * blockDim.y;
  // Number of complete groups of `rows_per_access` rows; iterating past this would
  // address rows outside [0, nrow).
  const int num_row_groups = nrow / rows_per_access;
  for (int row_i = global_thread_group_id; row_i < num_row_groups;
       row_i += num_global_thread_group) {
    ComputeType sum_stats[rows_per_access];
    ComputeType inv_rms_buf[rows_per_access];
    // Pass 1: load x and dy, normalize x in place, accumulate the partial dot product.
#pragma unroll
    for (int row_j = 0; row_j < rows_per_access; ++row_j) {
      const int global_row = row_i * rows_per_access + row_j;
      sum_stats[row_j] = 0;
      inv_rms_buf[row_j] = inv_rms[global_row];
      ComputeType* row_normalized_buf = normalized_buf[row_j];
      ComputeType* row_dy_buf = dy_buf[row_j];
      // Packs guaranteed in range by the caller: no bounds check.
#pragma unroll
      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row, global_col);
        load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
#pragma unroll
        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
          const int col = pack_offset + pack_j;
          row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];
          sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];
        }
      }
      // Remaining packs may fall beyond ncol: guard each one.
#pragma unroll
      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        if (global_col < ncol) {
          load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row,
                                          global_col);
          load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
#pragma unroll
          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
            const int col = pack_offset + pack_j;
            row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];
            sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];
          }
        }
      }
    }
    // Combine the per-thread partial sums across the thread group.
    ComputeType warp_sum_stats[rows_per_access];
#pragma unroll
    for (int row_j = 0; row_j < rows_per_access; ++row_j) {
      warp_sum_stats[row_j] =
          layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(
              sum_stats[row_j]);
    }
    // Pass 2: form dx in dy_buf and write it out.
#pragma unroll
    for (int row_j = 0; row_j < rows_per_access; ++row_j) {
      const int global_row = row_i * rows_per_access + row_j;
      ComputeType* row_normalized_buf = normalized_buf[row_j];
      ComputeType* row_dy_buf = dy_buf[row_j];
#pragma unroll
      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
          const int col = pack_offset + pack_j;
          const ComputeType norm_val =
              layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));
          row_dy_buf[col] =
              (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];
        }
        store.template store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
      }
#pragma unroll
      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
        const int pack_offset = pack_i * pack_size;
        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
        if (global_col < ncol) {
          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
            const int col = pack_offset + pack_j;
            const ComputeType norm_val =
                layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));
            row_dy_buf[col] =
                (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];
          }
          store.template store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
        }
      }
    }
  }
}
// Launches RmsNormGradWarpImpl with 128 threads per block arranged as
// (thread_group_width, 128 / thread_group_width).
//
// The number of blocks needed to give every group of `rows_per_access` rows its own
// thread group is passed to layer_norm::GetNumBlocks (32 waves) as the upper bound for
// the grid. Returns the query error, or the status reported by PeekAtLastError after
// the launch.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
         int rows_per_access>
GPU(Error_t) LaunchRmsNormGradWarpImpl(GPU(Stream_t) stream, const int nrow, const int ncol,
                                       LOAD_X load_x, LOAD_DY load_dy, STORE store,
                                       const ComputeType* inv_rms) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % thread_group_width == 0, "");
  constexpr int thread_groups_per_block = block_size / thread_group_width;

  // ceil((nrow / rows_per_access) / thread_groups_per_block)
  const int64_t wanted_blocks =
      (nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;

  int num_grid_blocks;
  const GPU(Error_t) query_status = layer_norm::GetNumBlocks(
      RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,
                          min_cols_per_thread, thread_group_width, rows_per_access>,
      block_size, 0, wanted_blocks, waves, &num_grid_blocks);
  if (query_status != GPU(Success)) { return query_status; }

  dim3 threads(thread_group_width, thread_groups_per_block);
  RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,
                      min_cols_per_thread, thread_group_width, rows_per_access>
      <<<num_grid_blocks, threads, 0, stream>>>(nrow, ncol, load_x, load_dy, store, inv_rms);
  return GPU(PeekAtLastError)();
}
// Selects the RmsNormGradWarpImpl instantiation for pack_size == 1 from the row width.
//
//   * ncol <= 0: ErrorInvalidValue.
//   * Narrow rows (ncol <= thread_group_width): one column per thread, fully
//     bounds-checked (min_cols_per_thread == 0). Two rows per access when nrow is even,
//     otherwise one.
//   * Wider rows: a full kWarpSize-wide group where each thread owns up to `max_col`
//     columns, of which the first `min_col` are accessed unchecked. The kernel therefore
//     relies on ncol > min_col * kWarpSize for the selected case.
//   * Anything wider than 32 * kWarpSize: ErrorInvalidValue.
//
// The narrow-row chain ends with a kWarpSize-wide case. When kWarpSize == 32 that case
// is unreachable (it repeats the preceding condition). NOTE(review): if kWarpSize is
// larger than 32 on some target (e.g. 64-lane wavefronts — confirm how kWarpSize is
// defined for this header), rows with 32 < ncol <= kWarpSize would otherwise be routed
// to the (2, 1) case, whose unchecked first pack would read and write columns >= ncol.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLaunchRmsNormGradWarpImplCols(
    GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
    STORE store, const ComputeType* inv_rms) {
  if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width)                                                       \
  else if (ncol <= (thread_group_width)*pack_size) {                                              \
    if (nrow % 2 == 0) {                                                                          \
      return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \
                                       0, thread_group_width, 2>(stream, nrow, ncol, load_x,      \
                                                                 load_dy, store, inv_rms);        \
    } else {                                                                                      \
      return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \
                                       0, thread_group_width, 1>(stream, nrow, ncol, load_x,      \
                                                                 load_dy, store, inv_rms);        \
    }                                                                                             \
  }
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
  // Closes the gap between 32 and kWarpSize before the unchecked-prefix cases below.
  DEFINE_ONE_ELIF(kWarpSize)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col)                                                       \
  else if (ncol <= (max_col)*kWarpSize) {                                                       \
    return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_col,   \
                                     min_col, kWarpSize, 1>(stream, nrow, ncol, load_x, load_dy, \
                                                            store, inv_rms);                    \
  }
  DEFINE_ONE_ELIF(2, 1)
  DEFINE_ONE_ELIF(4, 2)
  DEFINE_ONE_ELIF(8, 4)
  DEFINE_ONE_ELIF(12, 8)
  DEFINE_ONE_ELIF(16, 12)
  DEFINE_ONE_ELIF(20, 16)
  DEFINE_ONE_ELIF(24, 20)
  DEFINE_ONE_ELIF(28, 24)
  DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
  else {
    return GPU(ErrorInvalidValue);
  }
}
// Pack-size selection for the warp gradient implementation. Only pack size 1 is
// dispatched here; the call is forwarded to the column-count dispatcher.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormGradWarpImplPackSize(GPU(Stream_t) stream, const int64_t nrow,
                                                       const int64_t ncol, LOAD_X load_x,
                                                       LOAD_DY load_dy, STORE store,
                                                       const ComputeType* inv_rms) {
  constexpr int kSelectedPackSize = 1;
  return DispatchLaunchRmsNormGradWarpImplCols<LOAD_X, LOAD_DY, STORE, ComputeType,
                                               kSelectedPackSize>(stream, nrow, ncol, load_x,
                                                                  load_dy, store, inv_rms);
}
// RMS-norm input-gradient kernel where one thread block processes one row at a time
// (rows are strided by gridDim.x) and keeps the row in dynamic shared memory.
//
// Shared-memory layout: 2 * ncol ComputeType values — first x * inv_rms[row], then dy.
// Per row, with x_hat = x * inv_rms[row]:
//     s  = block-wide sum of dy * x_hat
//     dx = (dy - x_hat / ncol * s) * inv_rms[row]
// Each thread reads back in the second pass exactly the packs it wrote in the first.
// Requires ncol to be a multiple of pack_size (asserted).
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int block_size>
__global__ void RmsNormGradBlockSMemImpl(const int nrow, const int ncol, LOAD_X load_x,
                                         LOAD_DY load_dy, STORE store,
                                         const ComputeType* inv_rms) {
  extern __shared__ __align__(sizeof(double)) unsigned char dyn_smem[];
  // First half caches normalized x, second half caches dy.
  ComputeType* const x_hat_cache = reinterpret_cast<ComputeType*>(dyn_smem);
  ComputeType* const dy_cache = x_hat_cache + ncol;
  assert(ncol % pack_size == 0);
  const int packs_per_row = ncol / pack_size;

  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
    ComputeType thread_dot = 0;
    const ComputeType row_inv_rms = inv_rms[row];

    // Pass 1: fill the caches and accumulate this thread's share of sum(dy * x_hat).
    for (int pack_idx = threadIdx.x; pack_idx < packs_per_row; pack_idx += blockDim.x) {
      ComputeType x_vals[pack_size];
      ComputeType dy_vals[pack_size];
      const int first_col = pack_idx * pack_size;
      load_x.template load<pack_size>(x_vals, row, first_col);
      load_dy.template load<pack_size>(dy_vals, row, first_col);
#pragma unroll
      for (int k = 0; k < pack_size; ++k) {
        const int c = first_col + k;
        x_hat_cache[c] = x_vals[k] * row_inv_rms;
        dy_cache[c] = dy_vals[k];
        thread_dot += dy_cache[c] * x_hat_cache[c];
      }
    }

    const ComputeType row_dot =
        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_dot);

    // Pass 2: build dx from the cached values and store it.
    for (int pack_idx = threadIdx.x; pack_idx < packs_per_row; pack_idx += blockDim.x) {
      ComputeType dx_vals[pack_size];
      const int first_col = pack_idx * pack_size;
#pragma unroll
      for (int k = 0; k < pack_size; ++k) {
        const int c = first_col + k;
        const ComputeType scaled_x_hat =
            layer_norm::Div(x_hat_cache[c], static_cast<ComputeType>(ncol));
        dx_vals[k] = (dy_cache[c] - scaled_x_hat * row_dot) * row_inv_rms;
      }
      store.template store<pack_size>(dx_vals, row, first_col);
    }
  }
}
// Launches RmsNormGradBlockSMemImpl with `block_size` threads per block and
// `smem_size` bytes of dynamic shared memory.
//
// The grid size comes from layer_norm::GetNumBlocks (bounded by nrow, 32 waves).
// nrow and ncol are narrowed to int for the kernel. Returns the query error, or the
// status reported by PeekAtLastError after the launch.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int block_size>
GPU(Error_t) LaunchRmsNormGradBlockSMemImpl(GPU(Stream_t) stream, const int64_t nrow,
                                            const int64_t ncol, const size_t smem_size,
                                            LOAD_X load_x, LOAD_DY load_dy, STORE store,
                                            const ComputeType* inv_rms) {
  constexpr int waves = 32;

  int num_grid_blocks;
  const GPU(Error_t) query_status = layer_norm::GetNumBlocks(
      RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,
      block_size, smem_size, nrow, waves, &num_grid_blocks);
  if (query_status != GPU(Success)) { return query_status; }

  const int nrow_i32 = static_cast<int>(nrow);
  const int ncol_i32 = static_cast<int>(ncol);
  RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>
      <<<num_grid_blocks, block_size, smem_size, stream>>>(nrow_i32, ncol_i32, load_x, load_dy,
                                                           store, inv_rms);
  return GPU(PeekAtLastError)();
}
// Tries to launch the shared-memory gradient kernel, choosing the block size by occupancy.
//
// Shared memory needed: 2 * ncol * sizeof(ComputeType) bytes (x and dy are both cached).
// Procedure:
//   1. Query the occupancy of the 128-thread variant. If no block fits, set
//      *success = false and return Success without launching anything.
//   2. Otherwise probe 1024, 512, 256 threads (in that order) and launch the first one
//      whose blocks-per-multiprocessor equals that of the 128-thread variant.
//   3. If none matches, launch the 128-thread variant.
// *success is set to true whenever a launch is attempted. Occupancy-query errors are
// returned immediately.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize(
    GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
    STORE store, const ComputeType* inv_rms, bool* success) {
  constexpr int block_size_conf_1 = 128;
  constexpr int block_size_conf_2 = 256;
  constexpr int block_size_conf_3 = 512;
  constexpr int block_size_conf_4 = 1024;
  // ncol * 2: one copy for normalized x, one for dy.
  const size_t smem_size = ncol * sizeof(ComputeType) * 2;
  int baseline_active_blocks = 0;  // occupancy of the first probed configuration
  int probed_active_blocks = 0;    // occupancy of the configuration being probed
#define PROBE_BLOCK_SIZE_CONF(block_size_conf)                                                    \
  {                                                                                               \
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(                            \
        &probed_active_blocks,                                                                    \
        RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,                  \
                                 block_size_conf>,                                                \
        block_size_conf, smem_size);                                                              \
    if (err != GPU(Success)) { return err; }                                                      \
    if (baseline_active_blocks != 0) {                                                            \
      if (probed_active_blocks == baseline_active_blocks) {                                       \
        *success = true;                                                                          \
        return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,     \
                                              block_size_conf>(stream, nrow, ncol, smem_size,     \
                                                               load_x, load_dy, store, inv_rms);  \
      }                                                                                           \
    } else {                                                                                      \
      if (probed_active_blocks <= baseline_active_blocks) {                                       \
        *success = false;                                                                         \
        return GPU(Success);                                                                      \
      }                                                                                           \
      baseline_active_blocks = probed_active_blocks;                                              \
    }                                                                                             \
  }
  PROBE_BLOCK_SIZE_CONF(block_size_conf_1)
  PROBE_BLOCK_SIZE_CONF(block_size_conf_4)
  PROBE_BLOCK_SIZE_CONF(block_size_conf_3)
  PROBE_BLOCK_SIZE_CONF(block_size_conf_2)
#undef PROBE_BLOCK_SIZE_CONF
  // No larger configuration matched the baseline occupancy: use the smallest block.
  *success = true;
  return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,
                                        block_size_conf_1>(stream, nrow, ncol, smem_size, load_x,
                                                           load_dy, store, inv_rms);
}
// Chooses the pack size for the shared-memory gradient implementation.
//
// Pack size 2 is used when ncol is even and layer_norm::CanPackAs accepts 2 for
// load_x, load_dy and store; otherwise pack size 1. The result and `*success` come
// from the block-size dispatcher.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(
    GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
    STORE store, const ComputeType* inv_rms, bool* success) {
  const bool can_use_pack2 = (ncol % 2 == 0) && layer_norm::CanPackAs<LOAD_X>(load_x, 2)
                             && layer_norm::CanPackAs<LOAD_DY>(load_dy, 2)
                             && layer_norm::CanPackAs<STORE>(store, 2);
  if (!can_use_pack2) {
    return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
                                                              1>(stream, nrow, ncol, load_x,
                                                                 load_dy, store, inv_rms, success);
  }
  return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
                                                            2>(stream, nrow, ncol, load_x, load_dy,
                                                               store, inv_rms, success);
}
// RMS-norm input-gradient kernel where one thread block processes one row at a time
// (rows are strided by gridDim.x) without caching the row: x and dy are loaded once to
// compute the row statistic and a second time to produce the output.
//
// Per row:
//     s  = block-wide sum of dy * x * inv_rms[row]
//     dx = (dy - (x * inv_rms[row]) / ncol * s) * inv_rms[row]
// Requires ncol to be a multiple of pack_size (asserted).
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int block_size>
__global__ void RmsNormGradBlockUncachedImpl(const int nrow, const int ncol, LOAD_X load_x,
                                             LOAD_DY load_dy, STORE store,
                                             const ComputeType* inv_rms) {
  assert(ncol % pack_size == 0);
  const int packs_per_row = ncol / pack_size;

  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
    const ComputeType row_inv_rms = inv_rms[row];
    ComputeType thread_dot = 0;

    // Pass 1: accumulate this thread's share of the row statistic.
    for (int pack_idx = threadIdx.x; pack_idx < packs_per_row; pack_idx += blockDim.x) {
      ComputeType x_vals[pack_size];
      ComputeType dy_vals[pack_size];
      const int first_col = pack_idx * pack_size;
      load_x.template load<pack_size>(x_vals, row, first_col);
      load_dy.template load<pack_size>(dy_vals, row, first_col);
#pragma unroll
      for (int k = 0; k < pack_size; ++k) { thread_dot += dy_vals[k] * x_vals[k] * row_inv_rms; }
    }

    const ComputeType row_dot =
        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_dot);

    // Pass 2: reload the inputs, compute dx in place of dy and store it.
    for (int pack_idx = threadIdx.x; pack_idx < packs_per_row; pack_idx += blockDim.x) {
      ComputeType x_vals[pack_size];
      ComputeType dy_vals[pack_size];
      const int first_col = pack_idx * pack_size;
      load_x.template load<pack_size>(x_vals, row, first_col);
      load_dy.template load<pack_size>(dy_vals, row, first_col);
#pragma unroll
      for (int k = 0; k < pack_size; ++k) {
        const ComputeType scaled_x_hat =
            layer_norm::Div(x_vals[k] * row_inv_rms, static_cast<ComputeType>(ncol));
        dy_vals[k] = (dy_vals[k] - scaled_x_hat * row_dot) * row_inv_rms;
      }
      store.template store<pack_size>(dy_vals, row, first_col);
    }
  }
}
// Launches RmsNormGradBlockUncachedImpl with `block_size` threads per block and no
// dynamic shared memory.
//
// The grid size comes from layer_norm::GetNumBlocks (bounded by nrow, 32 waves).
// nrow and ncol are narrowed to the kernel's int parameters. Returns the query error,
// or the status reported by PeekAtLastError after the launch.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
         int block_size>
GPU(Error_t) LaunchRmsNormGradBlockUncachedImpl(GPU(Stream_t) stream, const int64_t nrow,
                                                const int64_t ncol, LOAD_X load_x,
                                                LOAD_DY load_dy, STORE store,
                                                const ComputeType* inv_rms) {
  constexpr int waves = 32;

  int num_grid_blocks;
  const GPU(Error_t) query_status = layer_norm::GetNumBlocks(
      RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,
      block_size, 0, nrow, waves, &num_grid_blocks);
  if (query_status != GPU(Success)) { return query_status; }

  RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>
      <<<num_grid_blocks, block_size, 0, stream>>>(static_cast<int>(nrow),
                                                   static_cast<int>(ncol), load_x, load_dy,
                                                   store, inv_rms);
  return GPU(PeekAtLastError)();
}
// Chooses the block size for the uncached gradient kernel.
//
// Block sizes 1024, 512, 256 and 128 are probed in that order with the occupancy API
// (no dynamic shared memory); the first one for which at least one block can be
// resident is launched. Occupancy-query errors are returned immediately. If no
// configuration can be resident, ErrorInvalidValue is returned.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(
    GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
    STORE store, const ComputeType* inv_rms) {
  constexpr int block_size_conf_1 = 128;
  constexpr int block_size_conf_2 = 256;
  constexpr int block_size_conf_3 = 512;
  constexpr int block_size_conf_4 = 1024;
  int resident_blocks = 0;
#define LAUNCH_IF_RESIDENT(block_size_conf)                                                       \
  {                                                                                               \
    GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(                            \
        &resident_blocks,                                                                         \
        RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,              \
                                     block_size_conf>,                                            \
        block_size_conf, 0);                                                                      \
    if (err != GPU(Success)) { return err; }                                                      \
    if (resident_blocks > 0) {                                                                    \
      return LaunchRmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,   \
                                                block_size_conf>(stream, nrow, ncol, load_x,      \
                                                                 load_dy, store, inv_rms);        \
    }                                                                                             \
  }
  // Largest block first.
  LAUNCH_IF_RESIDENT(block_size_conf_4)
  LAUNCH_IF_RESIDENT(block_size_conf_3)
  LAUNCH_IF_RESIDENT(block_size_conf_2)
  LAUNCH_IF_RESIDENT(block_size_conf_1)
#undef LAUNCH_IF_RESIDENT
  return GPU(ErrorInvalidValue);
}
// Chooses the pack size for the uncached gradient implementation.
//
// Pack size 2 is used only when all of the following hold: ncol is even,
// layer_norm::CanPackAs accepts 2 for load_x, load_dy and store, and ncol > kWarpSize.
// Otherwise pack size 1 is used.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormGradBlockUncachedImplPackSize(
    GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
    STORE store, const ComputeType* inv_rms) {
  const bool can_use_pack2 = (ncol % 2 == 0) && layer_norm::CanPackAs<LOAD_X>(load_x, 2)
                             && layer_norm::CanPackAs<LOAD_DY>(load_dy, 2)
                             && layer_norm::CanPackAs<STORE>(store, 2) && (ncol > kWarpSize);
  if (!can_use_pack2) {
    return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE,
                                                               ComputeType, 1>(
        stream, nrow, ncol, load_x, load_dy, store, inv_rms);
  }
  return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
                                                             2>(stream, nrow, ncol, load_x,
                                                                load_dy, store, inv_rms);
}
// RMS-norm input-gradient launcher for every ComputeType except double.
//
// Strategy:
//   * ncol <= 1024: warp implementation.
//   * otherwise   : try the shared-memory block implementation; when it reports that it
//                   did not launch, fall back to the uncached block implementation.
// Any error from a dispatcher is returned to the caller unchanged.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNormGrad(GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,
                  LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
  if (ncol <= 1024) {
    return DispatchLaunchRmsNormGradWarpImplPackSize(stream, nrow, ncol, load_x, load_dy, store,
                                                     inv_rms);
  }

  bool smem_impl_launched = false;
  const GPU(Error_t) smem_status = TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(
      stream, nrow, ncol, load_x, load_dy, store, inv_rms, &smem_impl_launched);
  if (smem_status != GPU(Success)) { return smem_status; }
  if (smem_impl_launched) { return GPU(Success); }

  return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy,
                                                            store, inv_rms);
}
// RMS-norm input-gradient launcher selected when ComputeType is double: always uses
// the uncached block implementation, regardless of ncol.
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNormGrad(GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,
                  LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
  return DispatchLaunchRmsNormGradBlockUncachedImplPackSize<LOAD_X, LOAD_DY, STORE, ComputeType>(
      stream, nrow, ncol, load_x, load_dy, store, inv_rms);
}
// Computes per-block partial sums of the weight gradient: for each column `col`,
//     sum over rows handled by this blockIdx.y of dy[row, col] * x[row, col] * inv_rms[row]
// and writes the partial result to b_weight_grad[blockIdx.y * ncol + col]. The partial
// rows (one per blockIdx.y) are not combined here.
//
// dy and x are indexed as row-major [nrow, ncol] arrays. Each thread accumulates
// `nproc_per_thread` running sums, one per row offset i * blockDim.y, and the block
// stages them in a kWarpSize x (kWarpSize + 1) shared tile before reducing along rows.
// The tile indexing requires nproc_per_thread * blockDim.y <= kWarpSize and
// blockDim.x <= kWarpSize.
//
// The element offset into dy/x is computed in 64-bit arithmetic so that tensors with
// 2^31 or more elements do not overflow `int`.
template<int nproc_per_thread, typename T, typename ComputeType>
__global__ void RmsNormParamGrad(int nrow, int ncol, const T* __restrict__ dy,
                                 const T* __restrict__ x, const ComputeType* __restrict__ inv_rms,
                                 T* __restrict__ b_weight_grad) {
  __shared__ ComputeType dweight[kWarpSize][kWarpSize + 1];
  ComputeType dweight_sum[nproc_per_thread];
#pragma unroll
  for (int i = 0; i < nproc_per_thread; ++i) { dweight_sum[i] = 0; }
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col < ncol) {
    // Each pass of this loop advances by kWarpSize * gridDim.y rows, so more than one
    // pass is needed only when nrow exceeds that amount.
    for (int j = blockIdx.y * kWarpSize + threadIdx.y; j < nrow; j += kWarpSize * gridDim.y) {
#pragma unroll
      for (int i = 0; i < nproc_per_thread; ++i) {
        int row = j + i * blockDim.y;
        if (row < nrow) {
          // 64-bit offset: row * ncol can exceed the range of int.
          const int64_t offset = static_cast<int64_t>(row) * ncol + col;
          const ComputeType dy_val = static_cast<ComputeType>(dy[offset]);
          const ComputeType x_val = static_cast<ComputeType>(x[offset]);
          const ComputeType inv_rms_val = inv_rms[row];
          // Accumulate this row's contribution to the weight gradient across passes.
          dweight_sum[i] += dy_val * x_val * inv_rms_val;
        }
      }
    }
  }
  // Stage the per-thread sums in shared memory: slot i of thread (x, y) goes to tile
  // row i * blockDim.y + y, tile column x.
#pragma unroll
  for (int i = 0; i < nproc_per_thread; ++i) {
    dweight[i * blockDim.y + threadIdx.y][threadIdx.x] = dweight_sum[i];
  }
  __syncthreads();
  // Read the tile transposed so that the threads sharing a threadIdx.y hold the values
  // of one output column, then reduce them with WarpReduceSum.
#pragma unroll
  for (int i = 0; i < nproc_per_thread; ++i) {
    // Threads with threadIdx.x == 0 write the reduced sums; each of them writes
    // nproc_per_thread output columns (one per value of i).
    const int row_in_block = threadIdx.y + i * blockDim.y;
    const int col = blockIdx.x * blockDim.x + row_in_block;
    if (col < ncol) {
      ComputeType dweight_val = dweight[threadIdx.x][row_in_block];
      ComputeType global_dweight = WarpReduceSum<ComputeType>(dweight_val);
      if (threadIdx.x == 0) {
        const int offset = blockIdx.y * ncol + col;
        b_weight_grad[offset] = global_dweight;
      }
    }
  }
}
// Computes the 2-D grid for RmsNormParamGrad.
//
// The tile edge equals block_dim_x; the configuration is rejected with
// ErrorInvalidValue unless nproc_per_thread * block_dim_y equals that edge.
//   *grid_dim_x = ceil(ncol / tile edge)
//   *grid_dim_y = result of layer_norm::GetNumBlocks for the kernel with
//                 block_dim_x * block_dim_y threads, no dynamic shared memory,
//                 ceil(nrow / tile edge) as the block-count bound and 1 wave.
// Returns the status of that query.
template<int nproc_per_thread, typename T>
GPU(Error_t) GetGrid2Dim(const int64_t nrow, const int64_t ncol, int block_dim_x,
                         int block_dim_y, int* grid_dim_x, int* grid_dim_y) {
  const int tile_edge = block_dim_x;
  if (nproc_per_thread * block_dim_y != tile_edge) { return GPU(ErrorInvalidValue); }

  const int num_tiles_x = (ncol + tile_edge - 1) / tile_edge;
  const int num_tiles_y = (nrow + tile_edge - 1) / tile_edge;
  *grid_dim_x = num_tiles_x;

  using ComputeType = typename layer_norm::DefaultComputeType<T>::type;
  const int threads_per_block = block_dim_x * block_dim_y;
  return layer_norm::GetNumBlocks(RmsNormParamGrad<nproc_per_thread, T, ComputeType>,
                                  threads_per_block, /*dynamic_smem_size*/ 0, num_tiles_y,
                                  /*waves*/ 1, grid_dim_y);
}
}
// namespace rms_norm
}
// namespace cuda
}
// namespace oneflow
#endif // ONEFLOW_CORE_CUDA_RMS_NORM_H_
oneflow/core/cuda/softmax.cuh
View file @
a715222c
...
...
@@ -17,10 +17,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_SOFTMAX_H_
#define ONEFLOW_CORE_CUDA_SOFTMAX_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cuda.h>
#include <cub/cub.cuh>
#include <math_constants.h>
#endif
#include <assert.h>
#include <cuda.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
...
...
@@ -32,7 +37,11 @@ namespace cuda {
namespace
softmax
{
#ifdef WITH_ROCM
constexpr
int
kWarpSize
=
64
;
#else
constexpr
int
kWarpSize
=
32
;
#endif
template
<
typename
T
>
struct
SumOp
{
...
...
@@ -47,14 +56,22 @@ struct MaxOp {
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
thread_group_width
=
kWarpSize
>
__inline__
__device__
T
WarpAllReduce
(
T
val
)
{
for
(
int
mask
=
thread_group_width
/
2
;
mask
>
0
;
mask
/=
2
)
{
#ifdef WITH_ROCM
val
=
ReductionOp
<
T
>
()(
val
,
__shfl_xor
(
val
,
mask
));
#else
val
=
ReductionOp
<
T
>
()(
val
,
__shfl_xor_sync
(
0xffffffff
,
val
,
mask
));
#endif
}
return
val
;
}
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
block_size
>
__inline__
__device__
T
BlockAllReduce
(
T
val
)
{
#ifdef WITH_ROCM
typedef
hipcub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
#else
typedef
cub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
#endif
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
result_broadcast
;
T
result
=
BlockReduce
(
temp_storage
).
Reduce
(
val
,
ReductionOp
<
T
>
());
...
...
@@ -68,12 +85,20 @@ __inline__ __device__ T Inf();
template
<
>
__inline__
__device__
float
Inf
<
float
>
()
{
#ifdef WITH_ROCM
return
__int_as_float
(
0x7f800000U
);
#else
return
CUDART_INF_F
;
#endif
}
template
<
>
__inline__
__device__
double
Inf
<
double
>
()
{
#ifdef WITH_ROCM
return
__longlong_as_double
(
0x7ff0000000000000ULL
);
#else
return
CUDART_INF
;
#endif
}
template
<
typename
T
>
...
...
@@ -126,26 +151,26 @@ __inline__ __device__ double Log<double>(double x) {
return
log
(
x
);
}
inline
cuda
Error_t
GetNumBlocks
(
int64_t
block_size
,
int64_t
max_blocks
,
int64_t
waves
,
inline
GPU
(
Error_t
)
GetNumBlocks
(
int64_t
block_size
,
int64_t
max_blocks
,
int64_t
waves
,
int
*
num_blocks
)
{
int
dev
;
{
cuda
Error_t
err
=
cuda
GetDevice
(
&
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
GetDevice
)
(
&
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
sm_count
;
{
cuda
Error_t
err
=
cuda
DeviceGetAttribute
(
&
sm_count
,
cudaDevAttr
MultiProcessorCount
,
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)
(
&
sm_count
,
GPU
MultiProcessorCount
,
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
int
tpm
;
{
cuda
Error_t
err
=
cuda
DeviceGetAttribute
(
&
tpm
,
cudaDevAttr
MaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GPU
(
DeviceGetAttribute
)
(
&
tpm
,
GPU
MaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
(
max_blocks
,
sm_count
*
tpm
/
block_size
*
waves
));
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
T
>
...
...
@@ -272,7 +297,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
row_buf
[
i
]
-=
warp_max
[
row_id
];
thread_sum
[
row_id
]
+=
Exp
(
row_buf
[
i
]);
}
else
{
__trap
();
TRAP
();
}
}
}
...
...
@@ -291,7 +316,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
row_buf
[
i
]
-=
Log
(
warp_sum
[
row_id
]);
}
else
{
__trap
();
TRAP
();
}
}
#pragma unroll
...
...
@@ -307,7 +332,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxWarpImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
LaunchSoftmaxWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
block_size
=
128
;
constexpr
int
waves
=
32
;
...
...
@@ -318,18 +343,18 @@ inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE s
(
rows
/
rows_per_access
+
thread_groups_per_block
-
1
)
/
thread_groups_per_block
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
,
algorithm
>
<<<
grid_dim_x
,
block_dim
,
0
,
stream
>>>
(
load
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxWarpImplPadding
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
DispatchSoftmaxWarpImplPadding
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
==
cols_per_thread
*
thread_group_width
)
{
return
LaunchSoftmaxWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
...
...
@@ -343,9 +368,9 @@ inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
typename
std
::
enable_if
<
pack_size
==
1
,
cuda
Error_t
>::
type
DispatchSoftmaxWarpImplCols
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
typename
std
::
enable_if
<
pack_size
==
1
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
...
...
@@ -403,14 +428,14 @@ typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpIm
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
typename
std
::
enable_if
<
pack_size
==
2
,
cuda
Error_t
>::
type
DispatchSoftmaxWarpImplCols
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
typename
std
::
enable_if
<
pack_size
==
2
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
...
...
@@ -452,13 +477,13 @@ typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpIm
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
DispatchSoftmaxWarpImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
)
{
return
DispatchSoftmaxWarpImplCols
<
LOAD
,
STORE
,
ComputeType
,
2
,
algorithm
>
(
stream
,
load
,
...
...
@@ -471,7 +496,7 @@ struct DispatchSoftmaxWarpImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxWarpImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
DispatchSoftmaxWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxWarpImplPackSize
<
LOAD
,
STORE
,
ComputeType
,
algorithm
>
()(
stream
,
load
,
store
,
rows
,
cols
);
...
...
@@ -520,7 +545,7 @@ __global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
pack
[
i
]
=
buf
[
i
*
num_packs
+
pack_id
]
-
Log
(
row_sum
);
}
else
{
__trap
();
TRAP
();
}
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
...
...
@@ -530,21 +555,21 @@ __global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
block_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
int
smem
,
inline
GPU
(
Error_t
)
LaunchSoftmaxBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
int
smem
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
,
algorithm
>
<<<
grid_dim_x
,
block_size
,
smem
,
stream
>>>
(
load
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
TryDispatchSoftmaxBlockSMemImplBlockSize
(
cuda
Stream_t
stream
,
LOAD
load
,
inline
GPU
(
Error_t
)
TryDispatchSoftmaxBlockSMemImplBlockSize
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
constexpr
int
block_size_conf_1
=
128
;
...
...
@@ -554,23 +579,23 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
const
size_t
smem
=
cols
*
sizeof
(
ComputeType
);
int
max_active_blocks_conf_1
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_1
,
SoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
,
algorithm
>
,
block_size_conf_1
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_1
<=
0
)
{
*
success
=
false
;
return
cuda
Success
;
return
GPU
(
Success
)
;
}
int
max_active_blocks_conf_4
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_4
,
SoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
,
algorithm
>
,
block_size_conf_4
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -579,11 +604,11 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
}
int
max_active_blocks_conf_3
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_3
,
SoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
,
algorithm
>
,
block_size_conf_3
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -592,11 +617,11 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
}
int
max_active_blocks_conf_2
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_2
,
SoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
,
algorithm
>
,
block_size_conf_2
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -610,7 +635,7 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
TryDispatchSoftmaxBlockSMemImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
if
(
cols
%
2
==
0
)
{
return
TryDispatchSoftmaxBlockSMemImplBlockSize
<
LOAD
,
STORE
,
ComputeType
,
2
,
algorithm
>
(
...
...
@@ -623,7 +648,7 @@ struct TryDispatchSoftmaxBlockSMemImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
TryDispatchSoftmaxBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
TryDispatchSoftmaxBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
return
TryDispatchSoftmaxBlockSMemImplPackSize
<
LOAD
,
STORE
,
ComputeType
,
algorithm
>
()(
...
...
@@ -664,7 +689,7 @@ __global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t r
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
pack
[
i
]
=
(
pack
[
i
]
-
row_max
)
-
Log
(
row_sum
);
}
else
{
__trap
();
TRAP
();
}
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
...
...
@@ -673,23 +698,23 @@ __global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t r
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
LaunchSoftmaxBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
block_size
=
1024
;
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
pack_size
,
block_size
,
algorithm
>
<<<
grid_dim_x
,
block_size
,
0
,
stream
>>>
(
load
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
DispatchSoftmaxBlockUncachedImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
)
{
return
LaunchSoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
2
,
algorithm
>
(
...
...
@@ -702,15 +727,15 @@ struct DispatchSoftmaxBlockUncachedImplPackSize {
};
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
inline
GPU
(
Error_t
)
DispatchSoftmaxBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxBlockUncachedImplPackSize
<
LOAD
,
STORE
,
ComputeType
,
algorithm
>
()(
stream
,
load
,
store
,
rows
,
cols
);
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchSoftmax
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchSoftmax
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<
1024
)
{
return
DispatchSoftmaxWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
...
...
@@ -718,30 +743,30 @@ DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
TryDispatchSoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchSoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchSoftmax
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchSoftmax
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
);
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLogSoftmax
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLogSoftmax
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
1024
)
{
return
DispatchSoftmaxWarpImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
...
...
@@ -749,22 +774,22 @@ DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t ro
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
TryDispatchSoftmaxBlockSMemImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchSoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLogSoftmax
(
cuda
Stream_t
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLogSoftmax
(
GPU
(
Stream_t
)
stream
,
LOAD
load
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxBlockUncachedImpl
<
LOAD
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load
,
store
,
rows
,
cols
);
...
...
@@ -807,7 +832,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
thread_sum
[
row_id
]
+=
row_dy_buf
[
pack_offset
+
i
];
}
else
{
__trap
();
TRAP
();
}
}
}
...
...
@@ -834,7 +859,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
row_dy_buf
[
pack_offset
+
i
]
-=
Exp
(
row_y_buf
[
pack_offset
+
i
])
*
warp_sum
[
row_id
];
}
else
{
__trap
();
TRAP
();
}
}
store
.
template
store
<
pack_size
>(
row_dy_buf
+
pack_offset
,
row
+
row_id
,
col
);
...
...
@@ -847,7 +872,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
bool
padding
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxGradWarpImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
inline
GPU
(
Error_t
)
LaunchSoftmaxGradWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
block_size
=
128
;
constexpr
int
waves
=
32
;
...
...
@@ -858,18 +883,18 @@ inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y,
(
rows
/
rows_per_access
+
thread_groups_per_block
-
1
)
/
thread_groups_per_block
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
num_blocks
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxGradWarpImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
cols_per_thread
,
thread_group_width
,
rows_per_access
,
padding
,
algorithm
>
<<<
grid_dim_x
,
block_dim
,
0
,
stream
>>>
(
load_y
,
load_dy
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
cols_per_thread
,
int
thread_group_width
,
int
rows_per_access
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxGradWarpImplPadding
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
DispatchSoftmaxGradWarpImplPadding
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
==
cols_per_thread
*
thread_group_width
)
{
...
...
@@ -885,10 +910,10 @@ inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
typename
std
::
enable_if
<
pack_size
==
1
,
cuda
Error_t
>::
type
DispatchSoftmaxGradWarpImplCols
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
typename
std
::
enable_if
<
pack_size
==
1
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxGradWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
...
...
@@ -947,16 +972,16 @@ typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWa
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
typename
std
::
enable_if
<
pack_size
==
2
,
cuda
Error_t
>::
type
DispatchSoftmaxGradWarpImplCols
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
typename
std
::
enable_if
<
pack_size
==
2
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxGradWarpImplCols
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
0
)
{
return
cuda
ErrorInvalidValue
;
}
if
(
cols
<=
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
...
...
@@ -999,14 +1024,14 @@ typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWa
DEFINE_ONE_ELIF
(
32
)
#undef DEFINE_ONE_ELIF
else
{
return
cuda
ErrorInvalidValue
;
return
GPU
(
ErrorInvalidValue
)
;
}
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
DispatchSoftmaxGradWarpImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
)
{
return
DispatchSoftmaxGradWarpImplCols
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
2
,
algorithm
>
(
...
...
@@ -1020,7 +1045,7 @@ struct DispatchSoftmaxGradWarpImplPackSize {
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxGradWarpImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
inline
GPU
(
Error_t
)
DispatchSoftmaxGradWarpImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxGradWarpImplPackSize
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
algorithm
>
()(
...
...
@@ -1053,7 +1078,7 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
thread_sum
+=
dy_pack
[
i
];
}
else
{
__trap
();
TRAP
();
}
}
}
...
...
@@ -1067,7 +1092,7 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
pack
[
i
]
=
dy_buf
[
i
*
num_packs
+
pack_id
]
-
Exp
(
y_buf
[
i
*
num_packs
+
pack_id
])
*
row_sum
;
}
else
{
__trap
();
TRAP
();
}
}
store
.
template
store
<
pack_size
>(
pack
,
row
,
pack_id
*
pack_size
);
...
...
@@ -1077,23 +1102,23 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
int
block_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxGradBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
LaunchSoftmaxGradBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
int
smem
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
,
algorithm
>
<<<
grid_dim_x
,
block_size
,
smem
,
stream
>>>
(
load_y
,
load_dy
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
TryDispatchSoftmaxGradBlockSMemImplBlockSize
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
TryDispatchSoftmaxGradBlockSMemImplBlockSize
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
...
...
@@ -1104,25 +1129,25 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
const
size_t
smem
=
cols
*
sizeof
(
ComputeType
)
*
2
;
int
max_active_blocks_conf_1
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_1
,
SoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_1
,
algorithm
>
,
block_size_conf_1
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_1
<=
0
)
{
*
success
=
false
;
return
cuda
Success
;
return
GPU
(
Success
)
;
}
int
max_active_blocks_conf_4
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_4
,
SoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_4
,
algorithm
>
,
block_size_conf_4
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_4
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -1132,12 +1157,12 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
}
int
max_active_blocks_conf_3
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_3
,
SoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_3
,
algorithm
>
,
block_size_conf_3
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_3
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -1147,12 +1172,12 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
}
int
max_active_blocks_conf_2
;
{
cuda
Error_t
err
=
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
GPU
(
Error_t
)
err
=
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks_conf_2
,
SoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size_conf_2
,
algorithm
>
,
block_size_conf_2
,
smem
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
max_active_blocks_conf_2
==
max_active_blocks_conf_1
)
{
*
success
=
true
;
...
...
@@ -1169,7 +1194,7 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
TryDispatchSoftmaxGradBlockSMemImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
if
(
cols
%
2
==
0
)
{
return
TryDispatchSoftmaxGradBlockSMemImplBlockSize
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
2
,
...
...
@@ -1185,7 +1210,7 @@ struct TryDispatchSoftmaxGradBlockSMemImplPackSize {
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
TryDispatchSoftmaxGradBlockSMemImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
TryDispatchSoftmaxGradBlockSMemImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
,
bool
*
success
)
{
...
...
@@ -1216,7 +1241,7 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
thread_sum
+=
dy_pack
[
i
];
}
else
{
__trap
();
TRAP
();
}
}
}
...
...
@@ -1233,7 +1258,7 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
dy_pack
[
i
]
-=
Exp
(
y_pack
[
i
])
*
row_sum
;
}
else
{
__trap
();
TRAP
();
}
}
store
.
template
store
<
pack_size
>(
dy_pack
,
row
,
pack_id
*
pack_size
);
...
...
@@ -1243,26 +1268,26 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
int
pack_size
,
Algorithm
algorithm
>
inline
cuda
Error_t
LaunchSoftmaxGradBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
LaunchSoftmaxGradBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
constexpr
int
block_size
=
1024
;
constexpr
int
waves
=
32
;
int
grid_dim_x
;
{
cuda
Error_t
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
GPU
(
Error_t
)
err
=
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim_x
);
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
SoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
pack_size
,
block_size
,
algorithm
>
<<<
grid_dim_x
,
block_size
,
0
,
stream
>>>
(
load_y
,
load_dy
,
store
,
rows
,
cols
);
return
cuda
PeekAtLastError
();
return
GPU
(
PeekAtLastError
)
();
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
struct
DispatchSoftmaxGradBlockUncachedImplPackSize
{
cuda
Error_t
operator
()(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
GPU
(
Error_t
)
operator
()(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
%
2
==
0
&&
cols
>
kWarpSize
)
{
return
LaunchSoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
2
,
algorithm
>
(
...
...
@@ -1276,7 +1301,7 @@ struct DispatchSoftmaxGradBlockUncachedImplPackSize {
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
,
Algorithm
algorithm
>
inline
cuda
Error_t
DispatchSoftmaxGradBlockUncachedImpl
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
inline
GPU
(
Error_t
)
DispatchSoftmaxGradBlockUncachedImpl
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxGradBlockUncachedImplPackSize
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
...
...
@@ -1285,8 +1310,8 @@ inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOA
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchSoftmaxGrad
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
1024
)
{
return
DispatchSoftmaxGradWarpImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
...
...
@@ -1294,23 +1319,23 @@ DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE s
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
TryDispatchSoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
GPU
(
Error_t
)
err
=
TryDispatchSoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchSoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchSoftmaxGrad
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchSoftmaxGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kSoftmax
>
(
stream
,
load_y
,
load_dy
,
store
,
...
...
@@ -1318,8 +1343,8 @@ DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE s
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLogSoftmaxGrad
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
inline
typename
std
::
enable_if
<!
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLogSoftmaxGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
if
(
cols
<=
1024
)
{
return
DispatchSoftmaxGradWarpImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
...
...
@@ -1327,23 +1352,23 @@ DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STOR
}
else
{
bool
dispatch_smem_impl_success
;
{
cuda
Error_t
err
=
TryDispatchSoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
GPU
(
Error_t
)
err
=
TryDispatchSoftmaxGradBlockSMemImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
,
&
dispatch_smem_impl_success
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
!
dispatch_smem_impl_success
)
{
return
DispatchSoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
template
<
typename
LOAD_Y
,
typename
LOAD_DY
,
typename
STORE
,
typename
ComputeType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
cuda
Error_t
>::
type
DispatchLogSoftmaxGrad
(
cuda
Stream_t
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
inline
typename
std
::
enable_if
<
std
::
is_same
<
ComputeType
,
double
>::
value
,
GPU
(
Error_t
)
>::
type
DispatchLogSoftmaxGrad
(
GPU
(
Stream_t
)
stream
,
LOAD_Y
load_y
,
LOAD_DY
load_dy
,
STORE
store
,
const
int64_t
rows
,
const
int64_t
cols
)
{
return
DispatchSoftmaxGradBlockUncachedImpl
<
LOAD_Y
,
LOAD_DY
,
STORE
,
ComputeType
,
Algorithm
::
kLogSoftmax
>
(
stream
,
load_y
,
load_dy
,
...
...
oneflow/core/cuda/unique.cuh
View file @
a715222c
...
...
@@ -16,8 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_UNIQUE_H_
#define ONEFLOW_CORE_CUDA_UNIQUE_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#include <device_launch_parameters.h>
#endif
#include "oneflow/core/common/permutation_iterator.h"
#include "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h"
...
...
@@ -49,82 +55,98 @@ __device__ __host__ __forceinline__ T* PtrOffset(void* ptr, size_t offset) {
__device__
__host__
__forceinline__
size_t
max
(
size_t
a
,
size_t
b
)
{
return
a
>
b
?
a
:
b
;
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
DoUnique
(
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
GPU
(
Error_t
)
DoUnique
(
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
void
*
workspace
,
size_t
*
workspace_size
,
GPU
(
Stream_t
)
stream
)
{
size_t
ws
=
*
workspace_size
;
cudaError_t
err
=
cub
::
DeviceSelect
::
Unique
<
const
Key
*
,
Key
*
,
Index
*>
(
#ifdef WITH_ROCM
GPU
(
Error_t
)
err
=
hipcub
::
DeviceSelect
::
Unique
<
const
Key
*
,
Key
*
,
Index
*>
(
workspace
,
ws
,
sorted_in
,
unique
,
num_unique
,
n
,
stream
);
#else
GPU
(
Error_t
)
err
=
cub
::
DeviceSelect
::
Unique
<
const
Key
*
,
Key
*
,
Index
*>
(
workspace
,
ws
,
sorted_in
,
unique
,
num_unique
,
n
,
stream
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
#endif
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
ws
;
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
DoUniqueWithCounts
(
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
GPU
(
Error_t
)
DoUniqueWithCounts
(
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
Index
*
counts
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
GPU
(
Stream_t
)
stream
)
{
size_t
ws
=
*
workspace_size
;
cudaError_t
err
=
cub
::
DeviceRunLengthEncode
::
Encode
<
const
Key
*
,
Key
*
,
Index
*
,
Index
*>
(
#ifdef WITH_ROCM
GPU
(
Error_t
)
err
=
hipcub
::
DeviceRunLengthEncode
::
Encode
<
const
Key
*
,
Key
*
,
Index
*
,
Index
*>
(
workspace
,
ws
,
sorted_in
,
unique
,
counts
,
num_unique
,
n
,
stream
);
#else
GPU
(
Error_t
)
err
=
cub
::
DeviceRunLengthEncode
::
Encode
<
const
Key
*
,
Key
*
,
Index
*
,
Index
*>
(
workspace
,
ws
,
sorted_in
,
unique
,
counts
,
num_unique
,
n
,
stream
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
#endif
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
ws
;
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
DispatchOutputCounts
(
Flag
flag
,
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
GPU
(
Error_t
)
DispatchOutputCounts
(
Flag
flag
,
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
Index
*
counts
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
size_t
*
workspace_size
,
GPU
(
Stream_t
)
stream
)
{
size_t
ws
=
*
workspace_size
;
if
((
flag
&
kOutputCounts
)
!=
0
)
{
cuda
Error_t
err
=
DoUniqueWithCounts
<
Key
,
Index
>
(
n
,
sorted_in
,
unique
,
num_unique
,
counts
,
GPU
(
Error_t
)
err
=
DoUniqueWithCounts
<
Key
,
Index
>
(
n
,
sorted_in
,
unique
,
num_unique
,
counts
,
workspace
,
&
ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
else
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
DoUnique
<
Key
,
Index
>
(
n
,
sorted_in
,
unique
,
num_unique
,
workspace
,
&
ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
ws
;
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
Key
,
typename
Index
,
typename
InverseIndicesIter
>
cuda
Error_t
DoGenInverseIndices
(
size_t
n
,
const
Key
*
sorted_in
,
GPU
(
Error_t
)
DoGenInverseIndices
(
size_t
n
,
const
Key
*
sorted_in
,
InverseIndicesIter
inverse_indices_iter
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
size_t
*
workspace_size
,
GPU
(
Stream_t
)
stream
)
{
size_t
ws
=
*
workspace_size
;
NotEqualToPreviousAdjacentIterator
<
Index
,
Key
>
unique_counting_iter
(
sorted_in
,
0
);
cudaError_t
err
=
#ifdef WITH_ROCM
GPU
(
Error_t
)
err
=
hipcub
::
DeviceScan
::
InclusiveSum
<
decltype
(
unique_counting_iter
),
InverseIndicesIter
>
(
workspace
,
ws
,
unique_counting_iter
,
inverse_indices_iter
,
n
,
stream
);
#else
GPU
(
Error_t
)
err
=
cub
::
DeviceScan
::
InclusiveSum
<
decltype
(
unique_counting_iter
),
InverseIndicesIter
>
(
workspace
,
ws
,
unique_counting_iter
,
inverse_indices_iter
,
n
,
stream
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
#endif
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
ws
;
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
Key
,
typename
Index
,
typename
InverseIndicesIter
>
cuda
Error_t
DispatchOutputInverseIndices
(
Flag
flag
,
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
GPU
(
Error_t
)
DispatchOutputInverseIndices
(
Flag
flag
,
size_t
n
,
const
Key
*
sorted_in
,
Key
*
unique
,
Index
*
num_unique
,
InverseIndicesIter
inverse_indices_iter
,
Index
*
counts
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
GPU
(
Stream_t
)
stream
)
{
size_t
dispatch_with_counts_ws
=
*
workspace_size
;
size_t
do_gen_inverse_indices_ws
=
*
workspace_size
;
{
cuda
Error_t
err
=
GPU
(
Error_t
)
err
=
DispatchOutputCounts
<
Key
,
Index
>
(
flag
,
n
,
sorted_in
,
unique
,
num_unique
,
counts
,
workspace
,
&
dispatch_with_counts_ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
((
flag
&
kOutputInverseIndices
)
!=
0
)
{
cuda
Error_t
err
=
DoGenInverseIndices
<
Key
,
Index
,
InverseIndicesIter
>
(
GPU
(
Error_t
)
err
=
DoGenInverseIndices
<
Key
,
Index
,
InverseIndicesIter
>
(
n
,
sorted_in
,
inverse_indices_iter
,
workspace
,
&
do_gen_inverse_indices_ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
max
(
dispatch_with_counts_ws
,
do_gen_inverse_indices_ws
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
T
>
...
...
@@ -136,8 +158,8 @@ __global__ void IotaKernel(size_t n, T* out) {
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
DoSort
(
size_t
n
,
const
Key
*
in
,
Key
*
sorted
,
Index
*
sorted_indices
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
GPU
(
Error_t
)
DoSort
(
size_t
n
,
const
Key
*
in
,
Key
*
sorted
,
Index
*
sorted_indices
,
void
*
workspace
,
size_t
*
workspace_size
,
GPU
(
Stream_t
)
stream
)
{
Index
*
indices
;
const
size_t
indices_size
=
GetCudaAlignedSize
(
n
*
sizeof
(
Index
));
void
*
sort_workspace
;
...
...
@@ -147,7 +169,7 @@ cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices,
sort_workspace
=
nullptr
;
sort_ws
=
0
;
}
else
{
if
(
*
workspace_size
<=
indices_size
)
{
return
cuda
ErrorInvalidValue
;
}
if
(
*
workspace_size
<=
indices_size
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
indices
=
PtrOffset
<
Index
>
(
workspace
,
0
);
sort_workspace
=
PtrOffset
<
Index
>
(
workspace
,
indices_size
);
sort_ws
=
*
workspace_size
-
indices_size
;
...
...
@@ -157,17 +179,22 @@ cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices,
const
int
num_blocks
=
static_cast
<
int
>
((
n
+
block_size
-
1
)
/
block_size
);
IotaKernel
<
Index
><<<
num_blocks
,
block_size
,
0
,
stream
>>>
(
n
,
indices
);
}
cudaError_t
err
=
cub
::
DeviceRadixSort
::
SortPairs
<
Key
,
Index
>
(
#ifdef WITH_ROCM
GPU
(
Error_t
)
err
=
hipcub
::
DeviceRadixSort
::
SortPairs
<
Key
,
Index
>
(
sort_workspace
,
sort_ws
,
in
,
sorted
,
indices
,
sorted_indices
,
n
,
0
,
sizeof
(
Key
)
*
8
,
stream
);
#else
GPU
(
Error_t
)
err
=
cub
::
DeviceRadixSort
::
SortPairs
<
Key
,
Index
>
(
sort_workspace
,
sort_ws
,
in
,
sorted
,
indices
,
sorted_indices
,
n
,
0
,
sizeof
(
Key
)
*
8
,
stream
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
#endif
if
(
err
!=
GPU
(
Success
))
{
return
err
;
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
indices_size
+
sort_ws
;
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
DispatchInputSorted
(
Flag
flag
,
size_t
n
,
const
Key
*
in
,
Key
*
unique
,
Index
*
num_unique
,
GPU
(
Error_t
)
DispatchInputSorted
(
Flag
flag
,
size_t
n
,
const
Key
*
in
,
Key
*
unique
,
Index
*
num_unique
,
Index
*
inverse_indices
,
Index
*
counts
,
void
*
workspace
,
size_t
*
workspace_size
,
cuda
Stream_t
stream
)
{
size_t
*
workspace_size
,
GPU
(
Stream_t
)
stream
)
{
if
((
flag
&
kInputSorted
)
!=
0
)
{
return
DispatchOutputInverseIndices
<
Key
,
Index
,
Index
*>
(
flag
,
n
,
in
,
unique
,
num_unique
,
inverse_indices
,
counts
,
workspace
,
...
...
@@ -190,7 +217,7 @@ cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique,
do_inverse_indices_ws
=
0
;
do_inverse_indices_workspace
=
nullptr
;
}
else
{
if
(
*
workspace_size
<=
sort_buffer_size
)
{
return
cuda
ErrorInvalidValue
;
}
if
(
*
workspace_size
<=
sort_buffer_size
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
sorted_in
=
PtrOffset
<
Key
>
(
workspace
,
0
);
sorted_indices
=
PtrOffset
<
Index
>
(
workspace
,
sorted_in_size
);
do_sort_ws
=
*
workspace_size
-
sort_buffer_size
;
...
...
@@ -199,38 +226,38 @@ cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique,
do_inverse_indices_workspace
=
do_sort_workspace
;
}
{
cuda
Error_t
err
=
DoSort
<
Key
,
Index
>
(
n
,
in
,
sorted_in
,
sorted_indices
,
do_sort_workspace
,
GPU
(
Error_t
)
err
=
DoSort
<
Key
,
Index
>
(
n
,
in
,
sorted_in
,
sorted_indices
,
do_sort_workspace
,
&
do_sort_ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
PermutationIterator
<
Index
,
Index
*
,
Index
*>
inverse_indices_iter
(
inverse_indices
,
sorted_indices
);
{
cuda
Error_t
err
=
DispatchOutputInverseIndices
<
Key
,
Index
,
decltype
(
inverse_indices_iter
)
>
(
GPU
(
Error_t
)
err
=
DispatchOutputInverseIndices
<
Key
,
Index
,
decltype
(
inverse_indices_iter
)
>
(
flag
,
n
,
sorted_in
,
unique
,
num_unique
,
inverse_indices_iter
,
counts
,
do_inverse_indices_workspace
,
&
do_inverse_indices_ws
,
stream
);
if
(
err
!=
cuda
Success
)
{
return
err
;
}
if
(
err
!=
GPU
(
Success
)
)
{
return
err
;
}
}
if
(
*
workspace_size
==
0
)
{
*
workspace_size
=
sort_buffer_size
+
max
(
do_sort_ws
,
do_inverse_indices_ws
);
}
return
cuda
Success
;
return
GPU
(
Success
)
;
}
}
}
// namespace
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
Launch
(
Flag
flag
,
size_t
n
,
const
Key
*
in
,
Key
*
unique
,
Index
*
num_unique
,
GPU
(
Error_t
)
Launch
(
Flag
flag
,
size_t
n
,
const
Key
*
in
,
Key
*
unique
,
Index
*
num_unique
,
Index
*
inverse_indices
,
Index
*
counts
,
void
*
workspace
,
size_t
workspace_size
,
cuda
Stream_t
stream
)
{
if
(
workspace_size
==
0
)
{
return
cuda
ErrorInvalidValue
;
}
GPU
(
Stream_t
)
stream
)
{
if
(
workspace_size
==
0
)
{
return
GPU
(
ErrorInvalidValue
)
;
}
return
DispatchInputSorted
<
Key
,
Index
>
(
flag
,
n
,
in
,
unique
,
num_unique
,
inverse_indices
,
counts
,
workspace
,
&
workspace_size
,
stream
);
}
template
<
typename
Key
,
typename
Index
>
cuda
Error_t
GetWorkspaceSize
(
Flag
flag
,
size_t
n
,
size_t
*
workspace_size
)
{
GPU
(
Error_t
)
GetWorkspaceSize
(
Flag
flag
,
size_t
n
,
size_t
*
workspace_size
)
{
*
workspace_size
=
0
;
return
DispatchInputSorted
<
Key
,
Index
>
(
flag
,
n
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
workspace_size
,
0
);
...
...
oneflow/core/device/cuda_util.cpp
View file @
a715222c
...
...
@@ -23,11 +23,7 @@ limitations under the License.
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/platform/include/pthread_fork.h"
#include "oneflow/core/device/device_context.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/vm/vm_util.h"
#ifdef WITH_CUDA
...
...
@@ -193,6 +189,10 @@ Maybe<double> GetCUDAMemoryUsed() {
int
deviceCount
=
0
;
cudaError_t
error_id
=
cudaGetDeviceCount
(
&
deviceCount
);
if
(
error_id
!=
cudaSuccess
)
{
return
Error
::
RuntimeError
()
<<
"Error: GetCUDAMemoryUsed fails :"
<<
cudaGetErrorString
(
error_id
);
}
CHECK_OR_RETURN
(
deviceCount
>
0
)
<<
"GPU device does not exist"
;
...
...
@@ -209,6 +209,26 @@ Maybe<double> GetCUDAMemoryUsed() {
return
(
total_memory
-
free_memory
);
}
static
std
::
once_flag
prop_init_flag
;
static
std
::
vector
<
cudaDeviceProp
>
device_props
;
void
InitDevicePropVectorSize
()
{
int
device_count
=
GetCudaDeviceCount
();
device_props
.
resize
(
device_count
);
}
void
InitDeviceProperties
(
int
device_id
)
{
std
::
call_once
(
prop_init_flag
,
InitDevicePropVectorSize
);
cudaDeviceProp
prop
{};
OF_CUDA_CHECK
(
cudaGetDeviceProperties
(
&
prop
,
device_id
));
device_props
[
device_id
]
=
prop
;
}
cudaDeviceProp
*
GetDeviceProperties
(
int
device_id
)
{
InitCudaContextOnce
(
device_id
);
return
&
device_props
[
device_id
];
}
void
InitCudaContextOnce
(
int
device_id
)
{
static
int
device_count
=
GetCudaDeviceCount
();
static
std
::
vector
<
std
::
once_flag
>
init_flags
=
std
::
vector
<
std
::
once_flag
>
(
device_count
);
...
...
@@ -217,6 +237,7 @@ void InitCudaContextOnce(int device_id) {
std
::
call_once
(
init_flags
[
device_id
],
[
&
]()
{
OF_CUDA_CHECK
(
cudaSetDevice
(
device_id
));
OF_CUDA_CHECK
(
cudaDeviceSynchronize
());
InitDeviceProperties
(
device_id
);
});
}
...
...
@@ -361,6 +382,10 @@ Maybe<double> GetCUDAMemoryUsed() {
int
deviceCount
=
0
;
hipError_t
error_id
=
hipGetDeviceCount
(
&
deviceCount
);
if
(
error_id
!=
hipSuccess
)
{
return
Error
::
RuntimeError
()
<<
"Error: GetCUDAMemoryUsed fails :"
<<
hipGetErrorString
(
error_id
);
}
CHECK_OR_RETURN
(
deviceCount
>
0
)
<<
"GPU device does not exist"
;
...
...
@@ -377,6 +402,26 @@ Maybe<double> GetCUDAMemoryUsed() {
return
(
total_memory
-
free_memory
);
}
static
std
::
once_flag
prop_init_flag
;
static
std
::
vector
<
hipDeviceProp_t
>
device_props
;
void
InitDevicePropVectorSize
()
{
int
device_count
=
GetCudaDeviceCount
();
device_props
.
resize
(
device_count
);
}
void
InitDeviceProperties
(
int
device_id
)
{
std
::
call_once
(
prop_init_flag
,
InitDevicePropVectorSize
);
hipDeviceProp_t
prop
{};
OF_CUDA_CHECK
(
hipGetDeviceProperties
(
&
prop
,
device_id
));
device_props
[
device_id
]
=
prop
;
}
hipDeviceProp_t
*
GetDeviceProperties
(
int
device_id
)
{
InitCudaContextOnce
(
device_id
);
return
&
device_props
[
device_id
];
}
void
InitCudaContextOnce
(
int
device_id
)
{
static
int
device_count
=
GetCudaDeviceCount
();
static
std
::
vector
<
std
::
once_flag
>
init_flags
=
std
::
vector
<
std
::
once_flag
>
(
device_count
);
...
...
@@ -385,11 +430,10 @@ void InitCudaContextOnce(int device_id) {
std
::
call_once
(
init_flags
[
device_id
],
[
&
]()
{
OF_CUDA_CHECK
(
hipSetDevice
(
device_id
));
OF_CUDA_CHECK
(
hipDeviceSynchronize
());
InitDeviceProperties
(
device_id
);
});
}
#endif // WITH_ROCM
}
// namespace oneflow
oneflow/core/device/cuda_util.cu
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cub/cub.cuh>
#include "oneflow/core/device/cuda_util.h"
namespace
oneflow
{
int
GetCudaSmVersion
()
{
int
sm_version
,
device_ordinal
;
OF_CUDA_CHECK
(
cudaGetDevice
(
&
device_ordinal
));
OF_CUDA_CHECK
(
cub
::
SmVersion
(
sm_version
,
device_ordinal
));
return
sm_version
;
}
int
GetCudaPtxVersion
()
{
int
ptx_version
;
OF_CUDA_CHECK
(
cub
::
PtxVersion
(
ptx_version
));
return
ptx_version
;
}
}
// namespace oneflow
oneflow/core/device/cuda_util.h
View file @
a715222c
...
...
@@ -30,6 +30,9 @@ limitations under the License.
#include <curand.h>
#include <nccl.h>
#include <cuda_fp16.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_pseudo_half.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
...
...
@@ -82,7 +85,10 @@ const char* NvjpegGetErrorString(nvjpegStatus_t error);
#define OF_NCCL_CHECK_OR_RETURN(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
return Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) << " (" \
<< _of_nccl_check_status << ") "
...
...
@@ -152,16 +158,14 @@ class CublasMathModeGuard final {
cublasMath_t
new_mode_
{};
};
int
GetCudaSmVersion
();
int
GetCudaPtxVersion
();
int
GetCudaDeviceIndex
();
int
GetCudaDeviceCount
();
Maybe
<
double
>
GetCUDAMemoryUsed
();
cudaDeviceProp
*
GetDeviceProperties
(
int
device_id
);
void
SetCudaDeviceIndex
(
int
device_id
);
void
CudaSynchronize
(
int
device_id
);
...
...
@@ -184,7 +188,11 @@ cudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active);
#include <rccl.h>
#include <hip/hip_fp16.h>
#include "oneflow/core/device/cuda_pseudo_half.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif // CUDA_VERSION >= 11000
namespace
oneflow
{
...
...
@@ -223,7 +231,10 @@ const char* CurandGetErrorString(hiprandStatus_t error);
#define OF_NCCL_CHECK_OR_RETURN(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
return Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) << " (" \
<< _of_nccl_check_status << ") "
...
...
@@ -275,6 +286,8 @@ int GetCudaDeviceCount();
Maybe
<
double
>
GetCUDAMemoryUsed
();
hipDeviceProp_t
*
GetDeviceProperties
(
int
device_id
);
void
SetCudaDeviceIndex
(
int
device_id
);
void
CudaSynchronize
(
int
device_id
);
...
...
oneflow/core/device/cudnn_conv_util.cpp
View file @
a715222c
...
...
@@ -341,7 +341,10 @@ ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)
}
ManagedCudnnConvResource
::~
ManagedCudnnConvResource
()
{
if
(
handle_
!=
nullptr
)
{
OF_CUDNN_CHECK
(
cudnnDestroy
(
handle_
));
}
if
(
handle_
!=
nullptr
)
{
Singleton
<
CudnnHandlePool
>::
Get
()
->
Put
(
handle_
);
handle_
=
nullptr
;
}
if
(
x_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
cudaFree
(
x_dptr_
));
}
if
(
w_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
cudaFree
(
w_dptr_
));
}
if
(
y_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
cudaFree
(
y_dptr_
));
}
...
...
@@ -349,7 +352,7 @@ ManagedCudnnConvResource::~ManagedCudnnConvResource() {
}
cudnnHandle_t
ManagedCudnnConvResource
::
cudnn_handle
()
{
if
(
handle_
==
nullptr
)
{
OF_CUDNN_CHECK
(
cudnnCreate
(
&
handle_
)
);
}
if
(
handle_
==
nullptr
)
{
handle_
=
Singleton
<
CudnnHandlePool
>::
Get
()
->
Get
(
);
}
return
handle_
;
}
...
...
@@ -392,7 +395,12 @@ bool operator==(const CudnnConvParams& a, const CudnnConvParams& b) {
}
DataType
GetConvDescDataType
(
DataType
data_type
,
bool
pseudo_half
)
{
return
(
data_type
==
DataType
::
kFloat16
&&
pseudo_half
)
?
DataType
::
kFloat
:
data_type
;
if
(
data_type
==
DataType
::
kFloat16
&&
pseudo_half
)
{
return
DataType
::
kFloat
;
}
else
if
(
data_type
==
DataType
::
kBFloat16
)
{
return
DataType
::
kFloat
;
}
return
data_type
;
}
cudnnStatus_t
GetCudnnConvWorkspaceSize
(
const
CudnnConvArgs
&
args
,
CudnnConvResource
*
res
,
...
...
@@ -669,25 +677,6 @@ perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,
<<
") requires memory "
<<
perf_vec
[
0
].
memory
;
}
// #if HIPDNN_VERSION < 7500
// // google [blacklist fft algorithms for strided dgrad]
// if (std::is_same<decltype(perf_vec[found_algo_idx].algo), hipdnnConvolutionBwdDataAlgo_t>::value) {
// int stride_dim = args.params.x_ndim - 2;
// bool blacklist =
// std::any_of(std::begin(args.params.stride), std::begin(args.params.stride) + stride_dim,
// [](int n) { return n != 1; });
// if (blacklist
// && (static_cast<hipdnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)
// == HIPDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
// || static_cast<hipdnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)
// == HIPDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
// perf_t algo_perf;
// SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo<algo_t>());
// return algo_perf;
// }
// }
// #endif
return
perf_vec
.
at
(
found_algo_idx
);
}
...
...
@@ -864,39 +853,37 @@ CudnnConvArgs::CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_t
wdesc
(
w_data_type
,
w_shape
,
data_format
),
cdesc
(
GetConvDescDataType
(
x_data_type
,
enable_pseudo_half
),
x_data_type
,
x_shape
,
ctx
),
heuristic
(
heuristic_search
),
deterministic
(
use_deterministic_algo_only
)
{
std
::
memset
(
&
params
,
0
,
sizeof
(
CudnnConvParams
));
OF_CUDNN_CHECK
(
hipdnnGetTensorNdDescriptor
(
xdesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
x_data_type
,
&
params
.
x_ndim
,
params
.
x_dims
,
params
.
x_strides
));
OF_CUDNN_CHECK
(
hipdnnGetTensorNdDescriptor
(
ydesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
y_data_type
,
&
params
.
y_ndim
,
params
.
y_dims
,
params
.
y_strides
));
OF_CUDNN_CHECK
(
hipdnnGetFilterNdDescriptor
(
wdesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
w_data_type
,
&
params
.
w_format
,
&
params
.
w_ndim
,
params
.
w_dims
));
hipdnnConvolutionMode_t
mode
;
int
conv_dim_size
=
x_shape
.
NumAxes
()
-
2
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
params
.
padding
[
i
]
=
cdesc
.
CD_padding
[
i
];
params
.
stride
[
i
]
=
cdesc
.
CD_stride
[
i
];
params
.
dilation
[
i
]
=
cdesc
.
CD_dilation
[
i
];
}
deterministic
(
use_deterministic_algo_only
),
max_ws_size
(
max_workspace_size
)
{
// std::memset(¶ms, 0, sizeof(CudnnConvParams));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.x_data_type, ¶ms.x_ndim, params.x_dims,
// params.x_strides));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.y_data_type, ¶ms.y_ndim, params.y_dims,
// params.y_strides));
// OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.w_data_type, ¶ms.w_format, ¶ms.w_ndim,
// params.w_dims));
// hipdnnConvolutionMode_t mode;
// int conv_dim_size = x_shape.NumAxes() - 2;
// for (int i=0; i<3; i++) {
// params.padding[i] = cdesc.CD_padding[i];
// params.stride[i] = cdesc.CD_stride[i];
// params.dilation[i] = cdesc.CD_dilation[i];
// }
mode
=
cdesc
.
CD_mode
;
params
.
data_type
=
cdesc
.
CD_data_type
;
//
mode = cdesc.CD_mode;
//
params.data_type = cdesc.CD_data_type;
//
OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,
//
&conv_dim_size, params.padding
, params.
stride,
//
params.dilation, &mode
,
&
params.
data_type)
);
CHECK_EQ
(
params
.
x_data_type
,
params
.
w_data_type
);
CHECK_EQ
(
params
.
x_ndim
,
params
.
w_ndim
)
;
//
CHECK_EQ(conv_dim_size + 2, params.x_ndim)
;
//
CHECK_EQ(params.x_data_type, params.w_data_type);
//
CHECK_EQ(params.x_ndim
, params.
w_ndim);
//
// CHECK_EQ(conv_dim_size + 2
, params.
x_ndim
);
// params.groups = cdesc.CD_groups
;
//
params.max_ws_size = max_workspace_size
;
params
.
groups
=
cdesc
.
CD_groups
;
// OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), ¶ms.groups));
params
.
max_ws_size
=
max_workspace_size
;
}
CudnnConvArgs
::
CudnnConvArgs
(
const
user_op
::
KernelComputeContext
&
ctx
,
DataType
x_data_type
,
...
...
@@ -910,51 +897,56 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType
wdesc
(
w_data_type
,
w_shape
,
data_format
),
cdesc
(
GetConvDescDataType
(
x_data_type
,
enable_pseudo_half
),
x_data_type
,
x_shape
,
ctx
),
heuristic
(
heuristic_search
),
deterministic
(
use_deterministic_algo_only
)
{
std
::
memset
(
&
params
,
0
,
sizeof
(
CudnnConvParams
));
OF_CUDNN_CHECK
(
hipdnnGetTensorNdDescriptor
(
xdesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
x_data_type
,
&
params
.
x_ndim
,
params
.
x_dims
,
params
.
x_strides
));
OF_CUDNN_CHECK
(
hipdnnGetTensorNdDescriptor
(
ydesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
y_data_type
,
&
params
.
y_ndim
,
params
.
y_dims
,
params
.
y_strides
));
OF_CUDNN_CHECK
(
hipdnnGetFilterNdDescriptor
(
wdesc
.
Get
(),
CudnnConvParams
::
kTensorMaxDims
,
&
params
.
w_data_type
,
&
params
.
w_format
,
&
params
.
w_ndim
,
params
.
w_dims
));
hipdnnConvolutionMode_t
mode
;
int
conv_dim_size
=
x_shape
.
NumAxes
()
-
2
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
params
.
padding
[
i
]
=
cdesc
.
CD_padding
[
i
];
params
.
stride
[
i
]
=
cdesc
.
CD_stride
[
i
];
params
.
dilation
[
i
]
=
cdesc
.
CD_dilation
[
i
];
}
deterministic
(
use_deterministic_algo_only
),
max_ws_size
(
max_workspace_size
)
{
// std::memset(¶ms, 0, sizeof(CudnnConvParams));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.x_data_type, ¶ms.x_ndim, params.x_dims,
// params.x_strides));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.y_data_type, ¶ms.y_ndim, params.y_dims,
// params.y_strides));
// OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
// ¶ms.w_data_type, ¶ms.w_format, ¶ms.w_ndim,
// params.w_dims));
// hipdnnConvolutionMode_t mode;
// int conv_dim_size = x_shape.NumAxes() - 2;
// for (int i=0; i<3; i++) {
// params.padding[i] = cdesc.CD_padding[i];
// params.stride[i] = cdesc.CD_stride[i];
// params.dilation[i] = cdesc.CD_dilation[i];
// }
mode
=
cdesc
.
CD_mode
;
params
.
data_type
=
cdesc
.
CD_data_type
;
//
mode = cdesc.CD_mode;
//
params.data_type = cdesc.CD_data_type;
//
OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,
//
&conv_dim_size, params.padding
, params.
stride,
//
params.dilation, &mode
,
&
params.
data_type)
);
CHECK_EQ
(
params
.
x_data_type
,
params
.
w_data_type
);
CHECK_EQ
(
params
.
x_ndim
,
params
.
w_ndim
)
;
//
CHECK_EQ(conv_dim_size + 2, params.x_ndim)
;
//
CHECK_EQ(params.x_data_type, params.w_data_type);
//
CHECK_EQ(params.x_ndim
, params.
w_ndim);
//
// CHECK_EQ(conv_dim_size + 2
, params.
x_ndim
);
// params.groups = cdesc.CD_groups
;
//
params.max_ws_size = max_workspace_size
;
params
.
groups
=
cdesc
.
CD_groups
;
// OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), ¶ms.groups));
params
.
max_ws_size
=
max_workspace_size
;
}
ManagedCudnnConvResource
::
ManagedCudnnConvResource
(
const
CudnnConvArgs
&
args
)
:
handle_
(
nullptr
),
x_dptr_
(
nullptr
),
w_dptr_
(
nullptr
),
y_dptr_
(
nullptr
),
ws_dptr_
(
nullptr
)
{
x_byte_size_
=
ByteSize4Tensor
(
args
.
params
.
x_dims
,
args
.
params
.
x_ndim
,
args
.
params
.
x_data_type
);
w_byte_size_
=
ByteSize4Tensor
(
args
.
params
.
w_dims
,
args
.
params
.
w_ndim
,
args
.
params
.
w_data_type
);
y_byte_size_
=
ByteSize4Tensor
(
args
.
params
.
y_dims
,
args
.
params
.
y_ndim
,
args
.
params
.
y_data_type
);
ws_byte_size_
=
args
.
params
.
max_ws_size
;
// x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type);
// w_byte_size_ = ByteSize4Tensor(args.params.w_dims, args.params.w_ndim, args.params.w_data_type);
// y_byte_size_ = ByteSize4Tensor(args.params.y_dims, args.params.y_ndim, args.params.y_data_type);
// ws_byte_size_ = args.params.max_ws_size;
x_byte_size_
=
0
;
w_byte_size_
=
0
;
y_byte_size_
=
0
;
ws_byte_size_
=
0
;
}
ManagedCudnnConvResource
::~
ManagedCudnnConvResource
()
{
if
(
handle_
!=
nullptr
)
{
OF_CUDNN_CHECK
(
hipdnnDestroy
(
handle_
));
}
if
(
handle_
!=
nullptr
)
{
Singleton
<
CudnnHandlePool
>::
Get
()
->
Put
(
handle_
);
handle_
=
nullptr
;
}
if
(
x_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
hipFree
(
x_dptr_
));
}
if
(
w_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
hipFree
(
w_dptr_
));
}
if
(
y_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
hipFree
(
y_dptr_
));
}
...
...
@@ -962,7 +954,7 @@ ManagedCudnnConvResource::~ManagedCudnnConvResource() {
}
hipdnnHandle_t
ManagedCudnnConvResource
::
cudnn_handle
()
{
if
(
handle_
==
nullptr
)
{
OF_CUDNN_CHECK
(
hipdnnCreate
(
&
handle_
)
);
}
if
(
handle_
==
nullptr
)
{
handle_
=
Singleton
<
CudnnHandlePool
>::
Get
()
->
Get
(
);
}
return
handle_
;
}
...
...
@@ -1005,7 +997,12 @@ bool operator==(const CudnnConvParams& a, const CudnnConvParams& b) {
}
DataType
GetConvDescDataType
(
DataType
data_type
,
bool
pseudo_half
)
{
return
(
data_type
==
DataType
::
kFloat16
&&
pseudo_half
)
?
DataType
::
kFloat
:
data_type
;
if
(
data_type
==
DataType
::
kFloat16
&&
pseudo_half
)
{
return
DataType
::
kFloat
;
}
else
if
(
data_type
==
DataType
::
kBFloat16
)
{
return
DataType
::
kFloat
;
}
return
data_type
;
}
hipdnnStatus_t
GetCudnnConvWorkspaceSize
(
const
CudnnConvArgs
&
args
,
CudnnConvResource
*
res
,
...
...
@@ -1035,38 +1032,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionFwdAlgoPerf_t> {
static
int
GetAlgoMaxCount
(
CudnnConvResource
*
res
)
{
int
max_algo_cnt
=
1
;
// OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
return
max_algo_cnt
;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
// res->cudnn_handle(), args.xdesc.Get(), args.wdesc.Get(), args.cdesc.Get(), args.ydesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
static
void
ExhaustiveSearch
(
CudnnConvArgs
&
args
,
CudnnConvResource
*
res
,
perf_t
*
perf
)
{
int
found_algo_cnt
=
0
;
size_t
ws
=
0
;
hipdnnConvolutionFwdAlgo_t
algo
;
hipdnnGetConvolutionForwardWorkspaceSize
(
res
->
cudnn_handle
(),
args
.
xdesc
.
Get
(),
args
.
wdesc
.
Get
(),
args
.
cdesc
.
Get
(),
args
.
ydesc
.
Get
(),
algo
,
&
ws
);
res
->
ws_byte_size_
=
ws
;
res
->
set_ws
();
args
.
params
.
max_ws_size
=
ws
;
OF_CUDNN_CHECK
(
hipdnnFindConvolutionForwardAlgorithmEx
(
res
->
cudnn_handle
(),
args
.
xdesc
.
Get
(),
res
->
x_const_dptr
(),
args
.
wdesc
.
Get
(),
res
->
w_const_dptr
(),
args
.
cdesc
.
Get
(),
args
.
ydesc
.
Get
(),
res
->
y_mut_dptr
(),
1
,
&
found_algo_cnt
,
perf
,
res
->
ws_dptr
(),
args
.
params
.
max_ws_size
));
args
.
max_ws_size
));
}
};
...
...
@@ -1076,40 +1053,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionBwdDataAlgoPerf_t> {
static
int
GetAlgoMaxCount
(
CudnnConvResource
*
res
)
{
int
max_algo_cnt
=
1
;
// OF_CUDNN_CHECK(
// cudnnGetConvolutionBackwardDataAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
return
max_algo_cnt
;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
// res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
// Runs hipDNN's exhaustive search for a backward-data convolution algorithm.
//
// Steps:
//   1. Query the workspace size for the convolution described by `args`.
//   2. Publish that size to `res` (base member ws_byte_size_ followed by set_ws()) and
//      record it in args.params.max_ws_size.
//   3. Call hipdnnFindConvolutionBackwardDataAlgorithmEx requesting a single result, which
//      is written to `perf`.
//
// args: convolution descriptors; args.params.max_ws_size is overwritten.
// res:  supplies the hipDNN handle, tensor pointers and the workspace pointer.
// perf: receives the (single) search result.
static void ExhaustiveSearch(CudnnConvArgs& args, CudnnConvResource* res, perf_t* perf) {
  int found_algo_cnt = 0;
  size_t ws = 0;
  // Value-initialised: the enum is passed by value to the size query below, so it must not
  // hold an indeterminate value.
  hipdnnConvolutionBwdDataAlgo_t algo{};
  // Checked like every other hipDNN call here; previously a failure silently left ws == 0.
  OF_CUDNN_CHECK(hipdnnGetConvolutionBackwardDataWorkspaceSize(
      res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(),
      algo, &ws));
  res->ws_byte_size_ = ws;
  res->set_ws();
  args.params.max_ws_size = ws;
  OF_CUDNN_CHECK(hipdnnFindConvolutionBackwardDataAlgorithmEx(
      res->cudnn_handle(), args.wdesc.Get(), res->w_const_dptr(), args.ydesc.Get(),
      res->y_const_dptr(), args.cdesc.Get(), args.xdesc.Get(), res->x_mut_dptr(), 1,
      &found_algo_cnt, perf, res->ws_dptr(), args.params.max_ws_size));
}
};
...
...
@@ -1119,40 +1074,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionBwdFilterAlgoPerf_t> {
// Number of algorithm results a backward-filter search may produce.
// Hard-wired to 1: the library query that would normally supply it is kept disabled below.
static int GetAlgoMaxCount(CudnnConvResource* res) {
  // OF_CUDNN_CHECK(
  //     cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
  constexpr int kMaxAlgoCount = 1;
  return kMaxAlgoCount;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
// res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
// Runs hipDNN's exhaustive search for a backward-filter convolution algorithm.
//
// Steps:
//   1. Query the workspace size for the convolution described by `args`.
//   2. Publish that size to `res` (base member ws_byte_size_ followed by set_ws()) and
//      record it in args.params.max_ws_size.
//   3. Call hipdnnFindConvolutionBackwardFilterAlgorithmEx requesting a single result, which
//      is written to `perf`.
//
// args: convolution descriptors; args.params.max_ws_size is overwritten.
// res:  supplies the hipDNN handle, tensor pointers and the workspace pointer.
// perf: receives the (single) search result.
static void ExhaustiveSearch(CudnnConvArgs& args, CudnnConvResource* res, perf_t* perf) {
  int found_algo_cnt = 0;
  size_t ws = 0;
  // Value-initialised: the enum is passed by value to the size query below, so it must not
  // hold an indeterminate value.
  hipdnnConvolutionBwdFilterAlgo_t algo{};
  // Checked like every other hipDNN call here; previously a failure silently left ws == 0.
  OF_CUDNN_CHECK(hipdnnGetConvolutionBackwardFilterWorkspaceSize(
      res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(),
      algo, &ws));
  res->ws_byte_size_ = ws;
  res->set_ws();
  args.params.max_ws_size = ws;
  OF_CUDNN_CHECK(hipdnnFindConvolutionBackwardFilterAlgorithmEx(
      res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.ydesc.Get(),
      res->y_const_dptr(), args.cdesc.Get(), args.wdesc.Get(), res->w_mut_dptr(), 1,
      &found_algo_cnt, perf, res->ws_dptr(), args.params.max_ws_size));
}
};
...
...
@@ -1200,4 +1133,6 @@ EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(hipdnnConvolutionBwdFilterAlg
}
// namespace oneflow
#endif // WITH_ROCM
oneflow/core/device/cudnn_conv_util.h
View file @
a715222c
...
...
@@ -308,6 +308,7 @@ struct CudnnConvArgs final {
CudnnConvDesc
cdesc
;
bool
heuristic
;
bool
deterministic
;
size_t
max_ws_size
;
OF_DISALLOW_COPY_AND_MOVE
(
CudnnConvArgs
);
CudnnConvArgs
(
const
user_op
::
InferContext
&
ctx
,
DataType
x_data_type
,
const
ShapeView
&
x_shape
,
...
...
@@ -333,17 +334,14 @@ class CudnnConvResource {
// Read-only pointer to the x buffer; provided by the concrete resource.
virtual const void* x_const_dptr() const = 0;
// Read-only pointer to the y buffer; provided by the concrete resource.
virtual const void* y_const_dptr() const = 0;
// Workspace pointer; the ExhaustiveSearch routines pass it to the hipDNN Find*AlgorithmEx calls.
virtual void* ws_dptr() = 0;
// Hook called by the ExhaustiveSearch routines right after they assign ws_byte_size_ below.
// The overrides visible in this header copy that value into their own ws_byte_size_ member.
virtual void set_ws() = 0;
// Workspace size in bytes, assigned directly through the base pointer by the
// ExhaustiveSearch routines before they call set_ws().
size_t ws_byte_size_;
};
class
AllocatedCudnnConvResource
final
:
public
CudnnConvResource
{
public:
AllocatedCudnnConvResource
(
hipdnnHandle_t
handle
,
void
*
x_dptr
,
void
*
w_dptr
,
void
*
y_dptr
,
void
*
ws_dptr
,
size_t
ws_byte_size
)
:
handle_
(
handle
),
x_dptr_
(
x_dptr
),
w_dptr_
(
w_dptr
),
y_dptr_
(
y_dptr
),
ws_dptr_
(
ws_dptr
),
ws_byte_size_
(
ws_byte_size
)
{}
// ~AllocatedCudnnConvResource() = default;
~
AllocatedCudnnConvResource
(){
if
(
ws_dptr_
!=
nullptr
)
{
OF_CUDA_CHECK
(
hipFree
(
ws_dptr_
));
}}
void
*
ws_dptr
)
:
handle_
(
handle
),
x_dptr_
(
x_dptr
),
w_dptr_
(
w_dptr
),
y_dptr_
(
y_dptr
),
ws_dptr_
(
ws_dptr
)
{}
~
AllocatedCudnnConvResource
()
=
default
;
// Returns the hipDNN handle stored by the constructor.
hipdnnHandle_t cudnn_handle() override { return handle_; }
// Returns the stored x pointer as read-only.
const void* x_const_dptr() const override { return x_dptr_; }
// Returns the stored w pointer as read-only.
const void* w_const_dptr() const override { return w_dptr_; }
...
...
@@ -351,13 +349,7 @@ class AllocatedCudnnConvResource final : public CudnnConvResource {
// Returns the stored x pointer for writing.
void* x_mut_dptr() override { return x_dptr_; }
// Returns the stored w pointer for writing.
void* w_mut_dptr() override { return w_dptr_; }
// Returns the stored y pointer for writing.
void* y_mut_dptr() override { return y_dptr_; }
// Returns the workspace pointer, allocating it on first use.
// When no workspace is held yet, ws_byte_size_ bytes are obtained with hipMalloc and the
// resulting pointer is cached in ws_dptr_ for subsequent calls.
void* ws_dptr() override {
  if (ws_dptr_ != nullptr) { return ws_dptr_; }
  OF_CUDA_CHECK(hipMalloc(&ws_dptr_, ws_byte_size_));
  return ws_dptr_;
}
void
set_ws
()
{
ws_byte_size_
=
CudnnConvResource
::
ws_byte_size_
;
}
size_t
ws_byte_size_
;
// Returns the workspace pointer stored by the constructor; nothing is allocated here.
void* ws_dptr() override { return ws_dptr_; }
private:
hipdnnHandle_t
handle_
;
...
...
@@ -379,8 +371,6 @@ class ManagedCudnnConvResource final : public CudnnConvResource {
// CudnnConvResource accessors; only declared here, the definitions are not part of this header excerpt.
// Read-only w pointer.
const void* w_const_dptr() const override;
// Read-only y pointer.
const void* y_const_dptr() const override;
// Workspace pointer.
void* ws_dptr() override;
void
set_ws
(){
ws_byte_size_
=
CudnnConvResource
::
ws_byte_size_
;
}
size_t
ws_byte_size_
;
private:
hipdnnHandle_t
handle_
;
...
...
@@ -391,7 +381,7 @@ class ManagedCudnnConvResource final : public CudnnConvResource {
size_t
x_byte_size_
;
size_t
w_byte_size_
;
size_t
y_byte_size_
;
//
size_t ws_byte_size_;
size_t
ws_byte_size_
;
};
bool
operator
==
(
const
CudnnConvParams
&
a
,
const
CudnnConvParams
&
b
);
...
...
oneflow/core/device/cudnn_util.cpp
View file @
a715222c
...
...
@@ -177,6 +177,45 @@ size_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type) {
return
byte_size
;
}
// Destroys every pooled cuDNN handle, device by device, then empties the map.
CudnnHandlePool::~CudnnHandlePool() {
  for (auto& entry : handle_list_map_) {
    const int64_t device_id = entry.first;
    // Guard is scoped to this device's handles (presumably makes device_id current).
    CudaCurrentDeviceGuard guard(device_id);
    std::vector<cudnnHandle_t>& pooled = entry.second;
    // Release from the back of the list until nothing is left.
    for (; !pooled.empty(); pooled.pop_back()) { OF_CUDNN_CHECK(cudnnDestroy(pooled.back())); }
  }
  handle_list_map_.clear();
}
// Hands out a cuDNN handle for the device reported by cudaGetDevice.
// A handle pooled for that device is reused when available; otherwise a new one is created.
// The mutex is held only while the pool is inspected, not during cudnnCreate.
cudnnHandle_t CudnnHandlePool::Get() {
  int current_device;
  OF_CUDA_CHECK(cudaGetDevice(&current_device));
  cudnnHandle_t result;
  bool reused = false;
  {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<cudnnHandle_t>& pooled = handle_list_map_[current_device];
    if (!pooled.empty()) {
      result = pooled.back();
      pooled.pop_back();
      reused = true;
    }
  }
  if (reused) { return result; }
  OF_CUDNN_CHECK(cudnnCreate(&result));
  return result;
}
// Returns `handle` to the pool, filing it under the device reported by cudaGetDevice
// at the time of the call.
void CudnnHandlePool::Put(cudnnHandle_t handle) {
  int current_device;
  OF_CUDA_CHECK(cudaGetDevice(&current_device));
  std::lock_guard<std::mutex> guard(mutex_);
  handle_list_map_[current_device].push_back(handle);
}
#endif // WITH_CUDA
#ifdef WITH_ROCM
...
...
@@ -302,10 +341,6 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
case
HIPDNN_DATA_FLOAT
:
case
HIPDNN_DATA_INT32
:
case
HIPDNN_DATA_INT8x4
:
// case CUDNN_DATA_UINT8x4: {
// byte_size = 4;
// break;
// }
case
HIPDNN_DATA_DOUBLE
:
{
byte_size
=
8
;
break
;
...
...
@@ -315,22 +350,9 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
break
;
}
case
HIPDNN_DATA_INT8
:
{
// case CUDNN_DATA_UINT8: {
byte_size
=
1
;
break
;
}
// #if HIPDNN_VERSION > 7200
// case CUDNN_DATA_INT8x32: {
// byte_size = 32;
// break;
// }
// #endif
// #if HIPDNN_VERSION >= 8100
// case CUDNN_DATA_BFLOAT16: {
// byte_size = 2;
// break;
// }
// #endif
default:
{
UNIMPLEMENTED
();
}
...
...
@@ -338,6 +360,45 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
return
byte_size
;
}
// Destroys every pooled hipDNN handle, device by device, then empties the map.
CudnnHandlePool::~CudnnHandlePool() {
  for (auto& entry : handle_list_map_) {
    const int64_t device_id = entry.first;
    // Guard is scoped to this device's handles (presumably makes device_id current).
    CudaCurrentDeviceGuard guard(device_id);
    std::vector<hipdnnHandle_t>& pooled = entry.second;
    // Release from the back of the list until nothing is left.
    for (; !pooled.empty(); pooled.pop_back()) { OF_CUDNN_CHECK(hipdnnDestroy(pooled.back())); }
  }
  handle_list_map_.clear();
}
// Hands out a hipDNN handle for the device reported by hipGetDevice.
// A handle pooled for that device is reused when available; otherwise a new one is created.
// The mutex is held only while the pool is inspected, not during hipdnnCreate.
hipdnnHandle_t CudnnHandlePool::Get() {
  int current_device;
  OF_CUDA_CHECK(hipGetDevice(&current_device));
  hipdnnHandle_t result;
  bool reused = false;
  {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<hipdnnHandle_t>& pooled = handle_list_map_[current_device];
    if (!pooled.empty()) {
      result = pooled.back();
      pooled.pop_back();
      reused = true;
    }
  }
  if (reused) { return result; }
  OF_CUDNN_CHECK(hipdnnCreate(&result));
  return result;
}
// Returns `handle` to the pool, filing it under the device reported by hipGetDevice
// at the time of the call.
void CudnnHandlePool::Put(hipdnnHandle_t handle) {
  int current_device;
  OF_CUDA_CHECK(hipGetDevice(&current_device));
  std::lock_guard<std::mutex> guard(mutex_);
  handle_list_map_[current_device].push_back(handle);
}
#endif // WITH_ROCM
template
<
typename
T
>
...
...
@@ -366,4 +427,34 @@ template const void* CudnnSPZeroPtr<float>();
template
const
void
*
CudnnSPZeroPtr
<
double
>();
template
const
void
*
CudnnSPZeroPtr
<
float16
>();
// Runtime-dtype dispatch to the CudnnSPOnePtr<T>() template.
// Handles kDouble, kFloat, kFloat16 and kBFloat16; any other dtype hits UNIMPLEMENTED().
const void* CudnnSPOnePtr(const DataType dtype) {
  switch (dtype) {
    case kDouble: return CudnnSPOnePtr<double>();
    case kFloat: return CudnnSPOnePtr<float>();
    case kFloat16: return CudnnSPOnePtr<float16>();
    // NOTE(guoran): kBFloat16 use float OnePtr
    case kBFloat16: return CudnnSPOnePtr<float>();
    default: UNIMPLEMENTED();
  }
}
// Runtime-dtype dispatch to the CudnnSPZeroPtr<T>() template.
// Handles kDouble, kFloat, kFloat16 and kBFloat16; any other dtype hits UNIMPLEMENTED().
const void* CudnnSPZeroPtr(const DataType dtype) {
  switch (dtype) {
    case kDouble: return CudnnSPZeroPtr<double>();
    case kFloat: return CudnnSPZeroPtr<float>();
    case kFloat16: return CudnnSPZeroPtr<float16>();
    // NOTE(guoran): kBFloat16 use float ZeroPtr
    case kBFloat16: return CudnnSPZeroPtr<float>();
    default: UNIMPLEMENTED();
  }
}
}
// namespace oneflow
oneflow/core/device/cudnn_util.h
View file @
a715222c
...
...
@@ -96,6 +96,22 @@ const void* CudnnSPOnePtr();
template
<
typename
T
>
const
void
*
CudnnSPZeroPtr
();
// Runtime-dtype counterpart of CudnnSPOnePtr<T>(): forwards to the instantiation selected by dtype.
// The definition handles kDouble, kFloat, kFloat16 and kBFloat16 (kBFloat16 uses the float instantiation).
const void* CudnnSPOnePtr(const DataType dtype);
// Runtime-dtype counterpart of CudnnSPZeroPtr<T>(): forwards to the instantiation selected by dtype.
// The definition handles kDouble, kFloat, kFloat16 and kBFloat16 (kBFloat16 uses the float instantiation).
const void* CudnnSPZeroPtr(const DataType dtype);
// Pool of reusable cuDNN handles, kept in one list per device id.
class CudnnHandlePool {
 public:
  CudnnHandlePool() = default;
  // Destroys all handles still held by the pool.
  ~CudnnHandlePool();
  // Returns a pooled handle for the current device, creating one if the pool has none.
  cudnnHandle_t Get();
  // Gives a handle back to the pool under the current device id.
  void Put(cudnnHandle_t handle);

 private:
  // Held by Get() and Put() while they access handle_list_map_.
  std::mutex mutex_;
  // device id -> handles currently available for reuse.
  HashMap<int64_t, std::vector<cudnnHandle_t>> handle_list_map_;
};
}
// namespace oneflow
#endif // WITH_CUDA
...
...
@@ -177,6 +193,22 @@ const void* CudnnSPOnePtr();
template
<
typename
T
>
const
void
*
CudnnSPZeroPtr
();
// Runtime-dtype counterpart of CudnnSPOnePtr<T>(): forwards to the instantiation selected by dtype.
// The definition handles kDouble, kFloat, kFloat16 and kBFloat16 (kBFloat16 uses the float instantiation).
const void* CudnnSPOnePtr(const DataType dtype);
// Runtime-dtype counterpart of CudnnSPZeroPtr<T>(): forwards to the instantiation selected by dtype.
// The definition handles kDouble, kFloat, kFloat16 and kBFloat16 (kBFloat16 uses the float instantiation).
const void* CudnnSPZeroPtr(const DataType dtype);
// Pool of reusable hipDNN handles, kept in one list per device id.
class CudnnHandlePool {
 public:
  CudnnHandlePool() = default;
  // Destroys all handles still held by the pool.
  ~CudnnHandlePool();
  // Returns a pooled handle for the current device, creating one if the pool has none.
  hipdnnHandle_t Get();
  // Gives a handle back to the pool under the current device id.
  void Put(hipdnnHandle_t handle);

 private:
  // Held by Get() and Put() while they access handle_list_map_.
  std::mutex mutex_;
  // device id -> handles currently available for reuse.
  HashMap<int64_t, std::vector<hipdnnHandle_t>> handle_list_map_;
};
}
// namespace oneflow
#endif // WITH_ROCM
...
...
oneflow/core/eager/blob_instruction_type.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/eager/blob_instruction_type.h"
#include "oneflow/core/vm/control_stream_type.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/access_blob_arg_cb_phy_instr_operand.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/eager/eager_blob_object.h"
namespace
oneflow
{
namespace
vm
{
// Wraps the instruction's eager blob in a temporary OfBlob and passes that object's
// address, as a uint64_t, to the callback stored in the operand.
// The operand must be non-null and of type AccessBlobArgCbPhyInstrOperand.
void AccessBlobByCallbackInstructionType::Compute(vm::Instruction* instruction) const {
  const auto& operand = instruction->phy_instr_operand();
  CHECK(static_cast<bool>(operand));
  const auto* access_operand =
      dynamic_cast<const vm::AccessBlobArgCbPhyInstrOperand*>(operand.get());
  CHECK_NOTNULL(access_operand);
  DeviceCtx* ctx = instruction->stream().device_ctx().get();
  // The OfBlob lives on the stack, so the address is only valid for the duration of the callback.
  OfBlob blob_view(ctx->stream(), access_operand->eager_blob_object()->blob());
  const uint64_t blob_address = reinterpret_cast<uint64_t>(&blob_view);
  access_operand->callback()(blob_address);
}
}
// namespace vm
}
// namespace oneflow
oneflow/core/eager/blob_instruction_type.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/common/stream_role.h"
#include "oneflow/core/common/singleton_ptr.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/ep_event.h"
#include "oneflow/core/vm/ep_device_context.h"
namespace
oneflow
{
namespace
vm
{
// Instruction type whose Compute() is declared here and defined out of line.
class AccessBlobByCallbackInstructionType final : public vm::InstructionType {
 public:
  AccessBlobByCallbackInstructionType() = default;
  ~AccessBlobByCallbackInstructionType() override = default;

  // Fixed name; the instruction argument is not consulted.
  std::string DebugName(const vm::Instruction& instruction) const override {
    return std::string("AccessBlobByCallback");
  }

  // Nothing to prepare: always succeeds.
  Maybe<void> Prepare(vm::Instruction* instruction) const override {
    return Maybe<void>::Ok();
  }

  void Compute(vm::Instruction* instruction) const override;
};
// Instruction type that does no work in Compute(); its effect is in InitInstructionStatus(),
// which attaches a reused ep event to the instruction's status buffer.
class EpRecordEventInstructionType final : public vm::InstructionType {
 public:
  EpRecordEventInstructionType() = default;
  ~EpRecordEventInstructionType() override = default;

  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; }

  // Runs the stream type's status initialisation, then stores an event obtained from the
  // stream's ep event provider in the status buffer.
  void InitInstructionStatus(Instruction* instruction) const override {
    auto* status = instruction->mut_status_buffer();
    auto* vm_stream = instruction->mut_stream();
    instruction->stream_type().InitInstructionStatus(*vm_stream, status);
    auto* provider =
        static_cast<EpDeviceCtx*>(vm_stream->device_ctx().get())->ep_event_provider();
    // The provider must exist; CHECK_NOTNULL aborts otherwise.
    const auto& event = CHECK_NOTNULL(provider)->GetReusedEpEvent();
    auto* raw_buffer = status->mut_buffer();
    EpOptionalEventRecordStatusQuerier::MutCast(raw_buffer)->reset_ep_event(event);
  }

  // Nothing to prepare: always succeeds.
  Maybe<void> Prepare(vm::Instruction* instruction) const override {
    return Maybe<void>::Ok();
  }

  std::string DebugName(const vm::Instruction&) const override {
    return std::string("RecordEvent");
  }

  // Intentionally empty.
  void Compute(vm::Instruction* instruction) const override {}
};
}
// namespace vm
// StreamRole visitor that selects the record-event instruction type for a stream role.
// Compute, host/device copy, comm-net and pinned-compute roles all resolve to the
// EpRecordEventInstructionType singleton; barrier, critical-section and lazy-job-launcher
// roles are unimplemented. The device type is not consulted by any visitor.
struct GetRecordEventInstructionType : public StreamRoleVisitor<GetRecordEventInstructionType> {
  static Maybe<const vm::InstructionType*> VisitCompute(DeviceType device_type) {
    return SingletonPtr<vm::EpRecordEventInstructionType>();
  }

  // The following roles share the compute answer.
  static Maybe<const vm::InstructionType*> VisitHost2Device(DeviceType device_type) {
    return VisitCompute(device_type);
  }
  static Maybe<const vm::InstructionType*> VisitDevice2Host(DeviceType device_type) {
    return VisitCompute(device_type);
  }
  static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DeviceType device_type) {
    return VisitCompute(device_type);
  }
  static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DeviceType device_type) {
    return VisitCompute(device_type);
  }

  // Roles without a record-event instruction type.
  static Maybe<const vm::InstructionType*> VisitBarrier(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }
  static Maybe<const vm::InstructionType*> VisitCriticalSection(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }
  static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }

  static Maybe<const vm::InstructionType*> VisitPinnedCompute(DeviceType device_type) {
    return VisitCompute(device_type);
  }
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
Prev
1
…
14
15
16
17
18
19
20
21
22
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment