Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
8f7de847
Commit
8f7de847
authored
Apr 25, 2023
by
yuguo960516yuguo
Browse files
dtk
parent
f262efc9
Pipeline
#248
failed with stages
in 0 seconds
Changes
121
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2730 additions
and
2730 deletions
+2730
-2730
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
+150
-150
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
...re/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
+109
-109
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
...core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
+396
-396
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_activation_grad.hip.cpp
...tive/broadcast_elementwise_binary_activation_grad.hip.cpp
+39
-39
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_comparision.hip.cpp
...rimitive/broadcast_elementwise_binary_comparision.hip.cpp
+38
-38
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_logical.hip.cpp
...cm/primitive/broadcast_elementwise_binary_logical.hip.cpp
+38
-38
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
.../rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
+35
-35
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
+237
-237
oneflow/core/ep/rocm/primitive/cast.hip.cpp
oneflow/core/ep/rocm/primitive/cast.hip.cpp
+148
-148
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
+254
-254
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
+95
-95
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
+116
-116
oneflow/core/ep/rocm/primitive/fill.hip.cpp
oneflow/core/ep/rocm/primitive/fill.hip.cpp
+151
-151
oneflow/core/ep/rocm/primitive/memcpy.cpp
oneflow/core/ep/rocm/primitive/memcpy.cpp
+62
-62
oneflow/core/ep/rocm/primitive/memset.cpp
oneflow/core/ep/rocm/primitive/memset.cpp
+59
-59
oneflow/core/ep/rocm/primitive/permute.hip.cpp
oneflow/core/ep/rocm/primitive/permute.hip.cpp
+333
-333
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
+107
-107
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
+116
-116
oneflow/core/ep/rocm/primitive/type_seq.h
oneflow/core/ep/rocm/primitive/type_seq.h
+77
-77
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
+170
-170
No files found.
Too many changes to show.
To preserve performance only
121 of 121+
files are displayed.
Plain diff
Email patch
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"
#include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
broadcast_elementwise_binary
{
template
<
typename
Src
,
typename
Dst
>
template
<
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
Src
,
Dst
>
{
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
pow
(
src0
,
src1
);
}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
pow
(
src0
,
src1
);
}
};
};
template
<
>
template
<
>
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
bool
,
bool
>
{
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
bool
,
bool
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
bool
src0
,
bool
src1
)
const
{
OF_DEVICE_FUNC
bool
operator
()(
bool
src0
,
bool
src1
)
const
{
return
static_cast
<
bool
>
(
pow
(
static_cast
<
double
>
(
src0
),
static_cast
<
double
>
(
src1
)));
return
static_cast
<
bool
>
(
pow
(
static_cast
<
double
>
(
src0
),
static_cast
<
double
>
(
src1
)));
}
}
};
};
template
<
>
template
<
>
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
half
,
half
>
{
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kPow
,
half
,
half
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
half
operator
()(
half
src0
,
half
src1
)
const
{
OF_DEVICE_FUNC
half
operator
()(
half
src0
,
half
src1
)
const
{
return
static_cast
<
half
>
(
pow
(
static_cast
<
float
>
(
src0
),
static_cast
<
float
>
(
src1
)));
return
static_cast
<
half
>
(
pow
(
static_cast
<
float
>
(
src0
),
static_cast
<
float
>
(
src1
)));
}
}
};
};
template
<
typename
Src
,
typename
Dst
>
template
<
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kGeluBackwardWithDyX
,
Src
,
Dst
>
{
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kGeluBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
coef
=
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
acos
(
static_cast
<
Src
>
(
-
1.0
)));
coef
=
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
acos
(
static_cast
<
Src
>
(
-
1.0
)));
#elif defined(__HIP_DEVICE_COMPILE__)
#elif defined(__HIP_DEVICE_COMPILE__)
coef
=
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
acos
(
static_cast
<
Src
>
(
-
1.0
)));
coef
=
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
acos
(
static_cast
<
Src
>
(
-
1.0
)));
#else
#else
coef
=
std
::
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
std
::
acos
(
static_cast
<
Src
>
(
-
1.0
)));
coef
=
std
::
sqrt
(
static_cast
<
Src
>
(
2.0
)
/
std
::
acos
(
static_cast
<
Src
>
(
-
1.0
)));
#endif
#endif
}
}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
static_cast
<
Src
>
(
0.5
)
return
static_cast
<
Src
>
(
0.5
)
*
(
static_cast
<
Src
>
(
1.0
)
+
erf
(
static_cast
<
Src
>
(
M_SQRT1_2
)
*
x
)
*
(
static_cast
<
Src
>
(
1.0
)
+
erf
(
static_cast
<
Src
>
(
M_SQRT1_2
)
*
x
)
+
x
*
coef
*
exp
(
static_cast
<
Src
>
(
-
0.5
)
*
x
*
x
))
+
x
*
coef
*
exp
(
static_cast
<
Src
>
(
-
0.5
)
*
x
*
x
))
*
dy
;
*
dy
;
}
}
Src
coef
;
Src
coef
;
};
};
template
<
typename
Src
,
typename
Dst
>
template
<
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kTanhBackwardWithDyX
,
Src
,
Dst
>
{
struct
BinaryFunctor
<
DeviceType
::
kCUDA
,
BinaryOp
::
kTanhBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
Src
tanh_val
=
tanh
(
x
);
Src
tanh_val
=
tanh
(
x
);
return
static_cast
<
Dst
>
(
dy
*
(
static_cast
<
Src
>
(
1.0
)
-
tanh_val
*
tanh_val
));
return
static_cast
<
Dst
>
(
dy
*
(
static_cast
<
Src
>
(
1.0
)
-
tanh_val
*
tanh_val
));
}
}
};
};
// /*********nv_bfloat16_kernel*******/
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
// template<>
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// }
// }
// };
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \
// template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \
// } \
// };
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000
// #endif // CUDA_VERSION >= 11000
// Defines a half/half specialization of BinaryFunctor for `op` that delegates
// to the float/float functor: operands are widened with __half2float, the
// result narrowed with __float2half. Used for activation-backward ops whose
// math is only implemented in float.
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op)                                         \
  template<>                                                                                  \
  struct BinaryFunctor<DeviceType::kCUDA, op, half, half> {                                   \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
                                                                                              \
    BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \
    OF_DEVICE_FUNC half operator()(half src0, half src1) const {                              \
      return __float2half(float_functor(__half2float(src0), __half2float(src1)));             \
    }                                                                                         \
  };
// Instantiate the pseudo-half (compute-in-float) functor for every
// activation-backward op that lacks a native half implementation.
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
}
// namespace broadcast_elementwise_binary
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
broadcast_elementwise_binary
{
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
Scalar
attr1
);
Scalar
attr1
);
namespace
{
namespace
{
class
BroadcastElementwiseBinaryFactoryImpl
:
public
BroadcastElementwiseBinaryFactory
{
class
BroadcastElementwiseBinaryFactoryImpl
:
public
BroadcastElementwiseBinaryFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryFactoryImpl
);
BroadcastElementwiseBinaryFactoryImpl
()
=
default
;
BroadcastElementwiseBinaryFactoryImpl
()
=
default
;
~
BroadcastElementwiseBinaryFactoryImpl
()
override
=
default
;
~
BroadcastElementwiseBinaryFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
)
override
{
size_t
max_num_dims
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
Scalar
(),
Scalar
());
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
Scalar
(),
Scalar
());
}
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
)
override
{
size_t
max_num_dims
,
Scalar
attr0
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
attr0
,
Scalar
());
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
attr0
,
Scalar
());
}
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
binary_op
,
DataType
src_type
,
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
binary_op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
,
Scalar
attr1
)
override
{
Scalar
attr0
,
Scalar
attr1
)
override
{
if
(
max_num_dims
>
kMaxNumDims
)
{
return
nullptr
;
}
if
(
max_num_dims
>
kMaxNumDims
)
{
return
nullptr
;
}
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
data_type_pair
),
\
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND
(
data_type_pair
)),
\
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
\
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
},
OF_PP_PAIR_FIRST(data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op
,
src_data_type_pair
,
dst_data_type_pair
)
\
binary_op, src_data_type_pair, dst_data_type_pair) \
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
src_data_type_pair
),
\
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND
(
dst_data_type_pair
)),
\
OF_PP_PAIR_SECOND(dst_data_type_pair)), \
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
src_data_type_pair
),
\
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \
OF_PP_PAIR_FIRST
(
dst_data_type_pair
)
>
},
OF_PP_PAIR_FIRST(dst_data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
data_type_pair
),
\
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND
(
data_type_pair
)),
\
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
\
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
},
OF_PP_PAIR_FIRST(data_type_pair)>},
static
const
std
::
map
<
static
const
std
::
map
<
std
::
tuple
<
BinaryOp
,
DataType
,
DataType
>
,
std
::
tuple
<
BinaryOp
,
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
Scalar
,
Scalar
)
>>
std
::
function
<
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
Scalar
,
Scalar
)
>>
new_broadcast_elementwise_binary_handle
{
new_broadcast_elementwise_binary_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
BINARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
,
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
,
BINARY_COMPARISION_OP_SEQ
BINARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
BINARY_COMPARISION_OP_SEQ
BINARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
const
auto
it
=
new_broadcast_elementwise_binary_handle
.
find
(
const
auto
it
=
new_broadcast_elementwise_binary_handle
.
find
(
std
::
make_tuple
(
binary_op
,
src_type
,
dst_type
));
std
::
make_tuple
(
binary_op
,
src_type
,
dst_type
));
if
(
it
!=
new_broadcast_elementwise_binary_handle
.
end
())
{
if
(
it
!=
new_broadcast_elementwise_binary_handle
.
end
())
{
return
it
->
second
(
attr0
,
attr1
);
return
it
->
second
(
attr0
,
attr1
);
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Register the implementation for the device tag used by this EP backend.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,
                           BroadcastElementwiseBinaryFactoryImpl);
}
// namespace
}
// namespace
}
// namespace broadcast_elementwise_binary
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
broadcast_elementwise_binary
{
namespace
{
namespace
{
// Raw storage wide enough (and aligned) to hold N packed elements of T, so a
// pack can be loaded/stored as a single aligned memory transaction.
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};

// Convenience alias for the packed storage type.
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template
<
typename
T
,
int
N
>
template
<
typename
T
,
int
N
>
union
Pack
{
union
Pack
{
static_assert
(
sizeof
(
PackType
<
T
,
N
>
)
==
sizeof
(
T
)
*
N
,
""
);
static_assert
(
sizeof
(
PackType
<
T
,
N
>
)
==
sizeof
(
T
)
*
N
,
""
);
OF_DEVICE_FUNC
Pack
()
{
OF_DEVICE_FUNC
Pack
()
{
// do nothing
// do nothing
}
}
PackType
<
T
,
N
>
storage
;
PackType
<
T
,
N
>
storage
;
T
elem
[
N
];
T
elem
[
N
];
};
};
template
<
size_t
max_dims
,
typename
IndexType
>
template
<
size_t
max_dims
,
typename
IndexType
>
struct
BroadcastElementwiseBinaryParams
{
struct
BroadcastElementwiseBinaryParams
{
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src0_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src0_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src1_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src1_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
dst_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
dst_index_helper
;
size_t
num_dims
;
size_t
num_dims
;
IndexType
src0_index_mask
[
max_dims
];
IndexType
src0_index_mask
[
max_dims
];
IndexType
src1_index_mask
[
max_dims
];
IndexType
src1_index_mask
[
max_dims
];
IndexType
count
{};
IndexType
count
{};
const
void
*
src0
{};
const
void
*
src0
{};
const
void
*
src1
{};
const
void
*
src1
{};
void
*
dst
{};
void
*
dst
{};
Scalar
attr0
;
Scalar
attr0
;
Scalar
attr1
;
Scalar
attr1
;
};
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
,
size_t
max_dims
,
size_t
src0_pack_size
,
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
typename
IndexType
>
size_t
src1_pack_size
,
typename
IndexType
>
__global__
void
BroadcastElementwiseBinaryGpu
(
__global__
void
BroadcastElementwiseBinaryGpu
(
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
)
{
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
)
{
constexpr
size_t
dst_pack_size
=
constexpr
size_t
dst_pack_size
=
src0_pack_size
>
src1_pack_size
?
src0_pack_size
:
src1_pack_size
;
src0_pack_size
>
src1_pack_size
?
src0_pack_size
:
src1_pack_size
;
static_assert
(
src0_pack_size
==
dst_pack_size
||
src0_pack_size
==
1
,
""
);
static_assert
(
src0_pack_size
==
dst_pack_size
||
src0_pack_size
==
1
,
""
);
static_assert
(
src1_pack_size
==
dst_pack_size
||
src1_pack_size
==
1
,
""
);
static_assert
(
src1_pack_size
==
dst_pack_size
||
src1_pack_size
==
1
,
""
);
const
PackType
<
Src
,
src0_pack_size
>*
src0
=
const
PackType
<
Src
,
src0_pack_size
>*
src0
=
reinterpret_cast
<
const
PackType
<
Src
,
src0_pack_size
>*>
(
params
.
src0
);
reinterpret_cast
<
const
PackType
<
Src
,
src0_pack_size
>*>
(
params
.
src0
);
const
PackType
<
Src
,
src1_pack_size
>*
src1
=
const
PackType
<
Src
,
src1_pack_size
>*
src1
=
reinterpret_cast
<
const
PackType
<
Src
,
src1_pack_size
>*>
(
params
.
src1
);
reinterpret_cast
<
const
PackType
<
Src
,
src1_pack_size
>*>
(
params
.
src1
);
PackType
<
Dst
,
dst_pack_size
>*
dst
=
reinterpret_cast
<
PackType
<
Dst
,
dst_pack_size
>*>
(
params
.
dst
);
PackType
<
Dst
,
dst_pack_size
>*
dst
=
reinterpret_cast
<
PackType
<
Dst
,
dst_pack_size
>*>
(
params
.
dst
);
IndexType
src0_index
[
max_dims
];
IndexType
src0_index
[
max_dims
];
IndexType
src1_index
[
max_dims
];
IndexType
src1_index
[
max_dims
];
IndexType
dst_index
[
max_dims
];
IndexType
dst_index
[
max_dims
];
size_t
num_dims
=
params
.
num_dims
;
size_t
num_dims
=
params
.
num_dims
;
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
offset
,
params
.
count
)
{
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
offset
,
params
.
count
)
{
params
.
dst_index_helper
.
OffsetToNdIndex
(
offset
,
dst_index
,
num_dims
);
params
.
dst_index_helper
.
OffsetToNdIndex
(
offset
,
dst_index
,
num_dims
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
max_dims
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_dims
;
++
i
)
{
if
(
i
<
num_dims
)
{
if
(
i
<
num_dims
)
{
src0_index
[
i
]
=
params
.
src0_index_mask
[
i
]
*
dst_index
[
i
];
src0_index
[
i
]
=
params
.
src0_index_mask
[
i
]
*
dst_index
[
i
];
src1_index
[
i
]
=
params
.
src1_index_mask
[
i
]
*
dst_index
[
i
];
src1_index
[
i
]
=
params
.
src1_index_mask
[
i
]
*
dst_index
[
i
];
}
else
{
}
else
{
src0_index
[
i
]
=
0
;
src0_index
[
i
]
=
0
;
src1_index
[
i
]
=
0
;
src1_index
[
i
]
=
0
;
}
}
}
}
const
IndexType
src0_offset
=
params
.
src0_index_helper
.
NdIndexToOffset
(
src0_index
,
num_dims
);
const
IndexType
src0_offset
=
params
.
src0_index_helper
.
NdIndexToOffset
(
src0_index
,
num_dims
);
const
IndexType
src1_offset
=
params
.
src1_index_helper
.
NdIndexToOffset
(
src1_index
,
num_dims
);
const
IndexType
src1_offset
=
params
.
src1_index_helper
.
NdIndexToOffset
(
src1_index
,
num_dims
);
Pack
<
Src
,
src0_pack_size
>
src0_pack
;
Pack
<
Src
,
src0_pack_size
>
src0_pack
;
src0_pack
.
storage
=
src0
[
src0_offset
];
src0_pack
.
storage
=
src0
[
src0_offset
];
Pack
<
Src
,
src1_pack_size
>
src1_pack
;
Pack
<
Src
,
src1_pack_size
>
src1_pack
;
src1_pack
.
storage
=
src1
[
src1_offset
];
src1_pack
.
storage
=
src1
[
src1_offset
];
Pack
<
Dst
,
dst_pack_size
>
dst_pack
;
Pack
<
Dst
,
dst_pack_size
>
dst_pack
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
(
params
.
attr0
,
params
.
attr1
);
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
(
params
.
attr0
,
params
.
attr1
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
dst_pack_size
;
++
j
)
{
for
(
int
j
=
0
;
j
<
dst_pack_size
;
++
j
)
{
const
Src
src0_val
=
const
Src
src0_val
=
(
src0_pack_size
==
dst_pack_size
)
?
src0_pack
.
elem
[
j
]
:
src0_pack
.
elem
[
0
];
(
src0_pack_size
==
dst_pack_size
)
?
src0_pack
.
elem
[
j
]
:
src0_pack
.
elem
[
0
];
const
Src
src1_val
=
const
Src
src1_val
=
(
src1_pack_size
==
dst_pack_size
)
?
src1_pack
.
elem
[
j
]
:
src1_pack
.
elem
[
0
];
(
src1_pack_size
==
dst_pack_size
)
?
src1_pack
.
elem
[
j
]
:
src1_pack
.
elem
[
0
];
dst_pack
.
elem
[
j
]
=
functor
(
src0_val
,
src1_val
);
dst_pack
.
elem
[
j
]
=
functor
(
src0_val
,
src1_val
);
}
}
dst
[
offset
]
=
dst_pack
.
storage
;
dst
[
offset
]
=
dst_pack
.
storage
;
}
}
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
typename
IndexType
>
size_t
src1_pack_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
int
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
void
LaunchKernel
(
Stream
*
stream
,
int
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
size_t
count
,
Scalar
attr0
,
Scalar
attr1
)
{
size_t
count
,
Scalar
attr0
,
Scalar
attr1
)
{
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
;
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
;
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
params
.
src0_index_mask
[
i
]
=
(
src0_dims
[
i
]
==
1
)
?
0
:
1
;
params
.
src0_index_mask
[
i
]
=
(
src0_dims
[
i
]
==
1
)
?
0
:
1
;
params
.
src1_index_mask
[
i
]
=
(
src1_dims
[
i
]
==
1
)
?
0
:
1
;
params
.
src1_index_mask
[
i
]
=
(
src1_dims
[
i
]
==
1
)
?
0
:
1
;
}
}
params
.
src0_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src0_dims
,
num_dims
);
params
.
src0_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src0_dims
,
num_dims
);
params
.
src1_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src1_dims
,
num_dims
);
params
.
src1_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src1_dims
,
num_dims
);
params
.
dst_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
dst_dims
,
num_dims
);
params
.
dst_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
dst_dims
,
num_dims
);
params
.
num_dims
=
num_dims
;
params
.
num_dims
=
num_dims
;
params
.
src0
=
src0
;
params
.
src0
=
src0
;
params
.
src1
=
src1
;
params
.
src1
=
src1
;
params
.
dst
=
dst
;
params
.
dst
=
dst
;
params
.
count
=
static_cast
<
IndexType
>
(
count
);
params
.
count
=
static_cast
<
IndexType
>
(
count
);
params
.
attr0
=
attr0
;
params
.
attr0
=
attr0
;
params
.
attr1
=
attr1
;
params
.
attr1
=
attr1
;
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
BroadcastElementwiseBinaryGpu
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
IndexType
>
BroadcastElementwiseBinaryGpu
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
->
cuda_stream
()
>>>
(
params
);
cuda_stream
->
cuda_stream
()
>>>
(
params
);
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
>
size_t
src1_pack_size
>
void
DispatchIndexType
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
void
DispatchIndexType
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
size_t
count
=
GetElementCount
(
num_dims
,
dst_dims
);
size_t
count
=
GetElementCount
(
num_dims
,
dst_dims
);
if
(
count
<
GetMaxVal
<
int32_t
>
())
{
if
(
count
<
GetMaxVal
<
int32_t
>
())
{
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int32_t
>
(
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int32_t
>
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
}
else
{
}
else
{
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int64_t
>
(
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int64_t
>
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
}
}
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
>
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
>
void
DispatchPackSize
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
void
DispatchPackSize
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
Scalar
attr1
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
const
void
*
/*src0*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
void
*
/*src0*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
int64_t
*
/*dst_dims*/
,
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
const
int64_t
*
/*dst_dims*/
,
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
nullptr
;
nullptr
;
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
1
)
{
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
1
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
1
>
;
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
1
>
;
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
4
)
{
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
4
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
4
>
;
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
4
>
;
}
else
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
4
)
{
}
else
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
4
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
4
>
;
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
4
>
;
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
1
)
{
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
1
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
1
>
;
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
1
>
;
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
func
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
attr0
,
attr1
);
func
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
attr0
,
attr1
);
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
void
DispatchNumDims
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
void
DispatchNumDims
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
Scalar
attr1
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*src0_pack_size*/
,
size_t
/*src1_pack_size*/
,
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*src0_pack_size*/
,
size_t
/*src1_pack_size*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
const
void
*
/*src0*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
const
void
*
/*src0*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
int64_t
*
/*dst_dims*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
int64_t
*
/*dst_dims*/
,
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
nullptr
;
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
nullptr
;
CHECK_NE
(
num_dims
,
1
);
CHECK_NE
(
num_dims
,
1
);
if
(
num_dims
==
2
)
{
if
(
num_dims
==
2
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
2
>
;
func
=
DispatchPackSize
<
op
,
T
,
R
,
2
>
;
}
else
if
(
num_dims
==
3
)
{
}
else
if
(
num_dims
==
3
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
3
>
;
func
=
DispatchPackSize
<
op
,
T
,
R
,
3
>
;
}
else
if
(
num_dims
==
4
)
{
}
else
if
(
num_dims
==
4
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
4
>
;
func
=
DispatchPackSize
<
op
,
T
,
R
,
4
>
;
}
else
if
(
num_dims
<=
8
)
{
}
else
if
(
num_dims
<=
8
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
8
>
;
func
=
DispatchPackSize
<
op
,
T
,
R
,
8
>
;
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
func
(
stream
,
src0_pack_size
,
src1_pack_size
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
func
(
stream
,
src0_pack_size
,
src1_pack_size
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
attr0
,
attr1
);
dst
,
attr0
,
attr1
);
}
}
template
<
size_t
max_pack_size
,
typename
T
,
typename
R
>
template
<
size_t
max_pack_size
,
typename
T
,
typename
R
>
size_t
GetPackSize
(
size_t
num_src_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
size_t
GetPackSize
(
size_t
num_src_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
{
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
{
static_assert
(
max_pack_size
>
0
&&
(
max_pack_size
&
(
max_pack_size
-
1
))
==
0
,
""
);
static_assert
(
max_pack_size
>
0
&&
(
max_pack_size
&
(
max_pack_size
-
1
))
==
0
,
""
);
CHECK
(
src0_dims
[
num_src_dims
-
1
]
!=
1
||
src1_dims
[
num_src_dims
-
1
]
!=
1
);
CHECK
(
src0_dims
[
num_src_dims
-
1
]
!=
1
||
src1_dims
[
num_src_dims
-
1
]
!=
1
);
auto
dst_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
dst
);
auto
dst_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
dst
);
for
(
size_t
pack_size
=
max_pack_size
;
pack_size
>
2
;
pack_size
/=
2
)
{
for
(
size_t
pack_size
=
max_pack_size
;
pack_size
>
2
;
pack_size
/=
2
)
{
bool
is_src0_supported
=
(
src0_dims
[
num_src_dims
-
1
]
==
1
)
bool
is_src0_supported
=
(
src0_dims
[
num_src_dims
-
1
]
==
1
)
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src0_dims
,
src0
);
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src0_dims
,
src0
);
bool
is_src1_supported
=
(
src1_dims
[
num_src_dims
-
1
]
==
1
)
bool
is_src1_supported
=
(
src1_dims
[
num_src_dims
-
1
]
==
1
)
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src1_dims
,
src1
);
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src1_dims
,
src1
);
if
(
is_src0_supported
&&
is_src1_supported
&&
(
dst_ptr
%
(
pack_size
*
sizeof
(
R
)))
==
0
)
{
if
(
is_src0_supported
&&
is_src1_supported
&&
(
dst_ptr
%
(
pack_size
*
sizeof
(
R
)))
==
0
)
{
return
pack_size
;
return
pack_size
;
}
}
}
}
return
1
;
return
1
;
}
}
constexpr
size_t
kMaxPackSize
=
4
;
constexpr
size_t
kMaxPackSize
=
4
;
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
simplified_num_dims
,
int64_t
*
simplified_src0_dims
,
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
simplified_num_dims
,
int64_t
*
simplified_src0_dims
,
const
void
*
src0
,
int64_t
*
simplified_src1_dims
,
const
void
*
src1
,
const
void
*
src0
,
int64_t
*
simplified_src1_dims
,
const
void
*
src1
,
int64_t
*
simplified_dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
int64_t
*
simplified_dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
CHECK_LE
(
simplified_num_dims
,
kMaxNumDims
);
CHECK_LE
(
simplified_num_dims
,
kMaxNumDims
);
size_t
pack_size
=
GetPackSize
<
kMaxPackSize
,
T
,
R
>
(
simplified_num_dims
,
simplified_src0_dims
,
size_t
pack_size
=
GetPackSize
<
kMaxPackSize
,
T
,
R
>
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
);
src0
,
simplified_src1_dims
,
src1
,
dst
);
size_t
src0_pack_size
=
1
;
size_t
src0_pack_size
=
1
;
size_t
src1_pack_size
=
1
;
size_t
src1_pack_size
=
1
;
if
(
simplified_src0_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
if
(
simplified_src0_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
simplified_src0_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
simplified_src0_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
src0_pack_size
=
pack_size
;
src0_pack_size
=
pack_size
;
}
}
if
(
simplified_src1_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
if
(
simplified_src1_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
simplified_src1_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
simplified_src1_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
src1_pack_size
=
pack_size
;
src1_pack_size
=
pack_size
;
}
}
simplified_dst_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
simplified_dst_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
DispatchNumDims
<
op
,
T
,
R
>
(
stream
,
src0_pack_size
,
src1_pack_size
,
simplified_num_dims
,
DispatchNumDims
<
op
,
T
,
R
>
(
stream
,
src0_pack_size
,
src1_pack_size
,
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
}
}
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryLhsScalarFunctor
{
struct
BinaryLhsScalarFunctor
{
__host__
__device__
BinaryLhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
__host__
__device__
BinaryLhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
scalar
,
src
);
}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
scalar
,
src
);
}
const
Src
scalar
;
const
Src
scalar
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
};
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryRhsScalarFunctor
{
struct
BinaryRhsScalarFunctor
{
__host__
__device__
BinaryRhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
__host__
__device__
BinaryRhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
src
,
scalar
);
}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
src
,
scalar
);
}
const
Src
scalar
;
const
Src
scalar
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
};
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryLhsScalarPtrFunctorFactory
{
struct
BinaryLhsScalarPtrFunctorFactory
{
__host__
__device__
BinaryLhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
__host__
__device__
BinaryLhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
Scalar
attr1
)
Scalar
attr1
)
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
__device__
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
__device__
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
return
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
return
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
}
}
const
Src
*
scalar_ptr
;
const
Src
*
scalar_ptr
;
Scalar
attr0
,
attr1
;
Scalar
attr0
,
attr1
;
};
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryRhsScalarPtrFunctorFactory
{
struct
BinaryRhsScalarPtrFunctorFactory
{
__host__
__device__
explicit
BinaryRhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
__host__
__device__
explicit
BinaryRhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
Scalar
attr1
)
Scalar
attr1
)
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
__device__
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
__device__
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
return
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
return
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
}
}
const
Src
*
scalar_ptr
;
const
Src
*
scalar_ptr
;
Scalar
attr0
,
attr1
;
Scalar
attr0
,
attr1
;
};
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
void
DispatchLaunch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
Src
*
src0
,
void
DispatchLaunch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
Src
*
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
Src
*
src1
,
Dst
*
dst
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
Src
*
src1
,
Dst
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
Scalar
attr0
,
Scalar
attr1
)
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
size_t
simplified_num_dims
=
0
;
size_t
simplified_num_dims
=
0
;
int64_t
simplified_src0_dims
[
kMaxNumDims
];
int64_t
simplified_src0_dims
[
kMaxNumDims
];
int64_t
simplified_src1_dims
[
kMaxNumDims
];
int64_t
simplified_src1_dims
[
kMaxNumDims
];
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_dst_dims
[
kMaxNumDims
];
SimplifyBroadcastDims
<
kMaxNumDims
>
(
num_src0_dims
,
src0_dims
,
num_src1_dims
,
src1_dims
,
SimplifyBroadcastDims
<
kMaxNumDims
>
(
num_src0_dims
,
src0_dims
,
num_src1_dims
,
src1_dims
,
&
simplified_num_dims
,
simplified_src0_dims
,
&
simplified_num_dims
,
simplified_src0_dims
,
simplified_src1_dims
,
simplified_dst_dims
);
simplified_src1_dims
,
simplified_dst_dims
);
CheckInplace
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
CheckInplace
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
);
simplified_dst_dims
,
dst
);
if
(
IsDimsEquals
(
simplified_num_dims
,
simplified_src0_dims
,
simplified_num_dims
,
if
(
IsDimsEquals
(
simplified_num_dims
,
simplified_src0_dims
,
simplified_num_dims
,
simplified_src1_dims
))
{
simplified_src1_dims
))
{
const
int64_t
elem_cnt
=
GetElementCount
(
simplified_num_dims
,
simplified_src0_dims
);
const
int64_t
elem_cnt
=
GetElementCount
(
simplified_num_dims
,
simplified_src0_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Binary
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Binary
(
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
(
attr0
,
attr1
),
elem_cnt
,
dst
,
src0
,
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
(
attr0
,
attr1
),
elem_cnt
,
dst
,
src0
,
src1
,
cuda_stream
->
cuda_stream
())));
src1
,
cuda_stream
->
cuda_stream
())));
}
else
{
}
else
{
if
(
simplified_num_dims
==
1
&&
simplified_src0_dims
[
0
]
==
1
)
{
if
(
simplified_num_dims
==
1
&&
simplified_src0_dims
[
0
]
==
1
)
{
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
BinaryLhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src0
,
attr0
,
attr1
),
BinaryLhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src0
,
attr0
,
attr1
),
simplified_src1_dims
[
0
],
dst
,
src1
,
cuda_stream
->
cuda_stream
())));
simplified_src1_dims
[
0
],
dst
,
src1
,
cuda_stream
->
cuda_stream
())));
}
else
if
(
simplified_num_dims
==
1
&&
simplified_src1_dims
[
0
]
==
1
)
{
}
else
if
(
simplified_num_dims
==
1
&&
simplified_src1_dims
[
0
]
==
1
)
{
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
BinaryRhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src1
,
attr0
,
attr1
),
BinaryRhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src1
,
attr0
,
attr1
),
simplified_src0_dims
[
0
],
dst
,
src0
,
cuda_stream
->
cuda_stream
())));
simplified_src0_dims
[
0
],
dst
,
src0
,
cuda_stream
->
cuda_stream
())));
}
else
{
}
else
{
LaunchWithSimplified
<
binary_op
,
Src
,
Dst
>
(
stream
,
simplified_num_dims
,
simplified_src0_dims
,
LaunchWithSimplified
<
binary_op
,
Src
,
Dst
>
(
stream
,
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
T
GetValue
(
Scalar
value
)
{
T
GetValue
(
Scalar
value
)
{
return
value
.
Value
<
T
>
();
return
value
.
Value
<
T
>
();
}
}
// Scalar has no native half accessor: read as float and narrow to half.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}
// NOTE: bfloat16 specialization disabled for the ROCm port (nv_bfloat16 is a
// CUDA >= 11 type); re-enable with the HIP equivalent when supported.
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
//   return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif  // CUDA_VERSION >= 11000
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
class
BroadcastElementwiseBinaryImpl
:
public
BroadcastElementwiseBinary
{
class
BroadcastElementwiseBinaryImpl
:
public
BroadcastElementwiseBinary
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryImpl
);
BroadcastElementwiseBinaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
BroadcastElementwiseBinaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
~
BroadcastElementwiseBinaryImpl
()
override
=
default
;
~
BroadcastElementwiseBinaryImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
Scalar
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
void
Launch
(
Stream
*
stream
,
Scalar
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
override
{
const
void
*
src1
,
void
*
dst
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
size_t
elem_cnt
=
GetElementCount
(
num_src1_dims
,
src1_dims
);
const
size_t
elem_cnt
=
GetElementCount
(
num_src1_dims
,
src1_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src0
),
attr0
,
attr1
),
elem_cnt
,
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src0
),
attr0
,
attr1
),
elem_cnt
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src1
),
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src1
),
cuda_stream
->
cuda_stream
())));
cuda_stream
->
cuda_stream
())));
}
}
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
Scalar
src1
,
void
*
dst
)
override
{
Scalar
src1
,
void
*
dst
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
size_t
elem_cnt
=
GetElementCount
(
num_src0_dims
,
src0_dims
);
const
size_t
elem_cnt
=
GetElementCount
(
num_src0_dims
,
src0_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src1
),
attr0
,
attr1
),
elem_cnt
,
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src1
),
attr0
,
attr1
),
elem_cnt
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src0
),
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src0
),
cuda_stream
->
cuda_stream
())));
cuda_stream
->
cuda_stream
())));
}
}
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
override
{
void
*
dst
)
override
{
DispatchLaunch
<
binary_op
,
Src
,
Dst
>
(
DispatchLaunch
<
binary_op
,
Src
,
Dst
>
(
stream
,
num_src0_dims
,
src0_dims
,
reinterpret_cast
<
const
Src
*>
(
src0
),
num_src1_dims
,
stream
,
num_src0_dims
,
src0_dims
,
reinterpret_cast
<
const
Src
*>
(
src0
),
num_src1_dims
,
src1_dims
,
reinterpret_cast
<
const
Src
*>
(
src1
),
reinterpret_cast
<
Dst
*>
(
dst
),
attr0
,
attr1
);
src1_dims
,
reinterpret_cast
<
const
Src
*>
(
src1
),
reinterpret_cast
<
Dst
*>
(
dst
),
attr0
,
attr1
);
}
}
private:
private:
Scalar
attr0
,
attr1
;
Scalar
attr0
,
attr1
;
};
};
}
// namespace
}
// namespace
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
Scalar
attr1
)
{
Scalar
attr1
)
{
return
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
return
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
new
BroadcastElementwiseBinaryImpl
<
binary_op
,
Src
,
Dst
>
(
attr0
,
attr1
));
new
BroadcastElementwiseBinaryImpl
<
binary_op
,
Src
,
Dst
>
(
attr0
,
attr1
));
}
}
}
// namespace broadcast_elementwise_binary
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_activation_grad.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicitly instantiate the factory for every activation-backward binary op
// over every floating type; Src and Dst are the same type for these ops.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,       \
                                                                           data_type_pair)  \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<       \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(       \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ,
                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_comparision.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicit instantiations for the comparison ops. Comparison kernels read
// values of any supported source dtype and always write bool results, hence
// the bool-only destination sequence below.
// NOTE(review): "COMPARASION" is a historical typo kept for consistency with
// the CUDA counterpart of this file.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                          \
    binary_op, src_data_type_pair, dst_data_type_pair)                                           \
  template std::unique_ptr<BroadcastElementwiseBinary>                                           \
  NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair),                 \
                                OF_PP_PAIR_FIRST(dst_data_type_pair)>(Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_logical.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicit instantiations for the logical ops (any supported src dtype,
// bool destination).
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
                                                                   dst_data_type_pair)            \
  template std::unique_ptr<BroadcastElementwiseBinary>                                            \
  NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair),                  \
                                OF_PP_PAIR_FIRST(dst_data_type_pair)>(Scalar attr0, Scalar attr1);

// NOTE(review): the op sequence concatenates BINARY_COMPARISION_OP_SEQ with
// BINARY_LOGICAL_OP_SEQ, so comparison ops are also instantiated from this
// translation unit; verify this does not duplicate the instantiations emitted
// by broadcast_elementwise_binary_comparision.hip.cpp.
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,
                                 CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicit instantiations for the elementwise math ops; src and dst dtypes
// are identical for arithmetic kernels.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary>                                         \
  NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),                   \
                                OF_PP_PAIR_FIRST(data_type_pair)>(Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#ifdef WITH_ROCM
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_matmul
{
namespace
broadcast_matmul
{
namespace
internal
{
namespace
internal
{
namespace
{
namespace
{
constexpr
size_t
kMaxNumDims
=
8
;
constexpr
size_t
kMaxNumDims
=
8
;
// Maps a OneFlow dtype to its hipBLAS storage dtype, or NullOpt when the
// dtype has no hipBLAS mapping (bfloat16 is deliberately not wired up in
// this ROCm port — see the commented-out CUDA_R_16BF branch upstream).
Optional<hipblasDatatype_t> OptCudaDataType(DataType data_type) {
  switch (data_type) {
    case kFloat: return HIPBLAS_R_32F;
    case kDouble: return HIPBLAS_R_64F;
    case kFloat16: return HIPBLAS_R_16F;
    default: return NullOpt;
  }
}
// Like OptCudaDataType, but hard-fails (CHECK) on unsupported dtypes instead
// of returning an empty optional.
hipblasDatatype_t GetCudaDataType(DataType data_type) {
  const auto opt_type = OptCudaDataType(data_type);
  CHECK(opt_type.has_value());
  // value_or's fallback is unreachable after the CHECK above.
  return opt_type.value_or(HIPBLAS_R_32F);
}
// Host-side alpha/beta scalar for hipblasGemm*Ex: the API expects a double
// when the compute type is 64-bit and a float otherwise, so both layouts
// share one union.
union CublasScalarParameter {
  double d;  // written when compute type is HIPBLAS_R_64F
  float s;   // written for HIPBLAS_R_32F and HIPBLAS_R_16F compute
};
// Converts a Scalar into the union layout hipBLAS expects for the given
// compute type. Half-precision compute still takes a float scalar here; the
// caller converts it to __half separately when needed.
CublasScalarParameter GetCublasScalarParameter(Scalar scalar, hipblasDatatype_t compute_type) {
  CublasScalarParameter sp{};
  switch (compute_type) {
    case HIPBLAS_R_64F: sp.d = scalar.Value<double>(); break;
    case HIPBLAS_R_32F:
    case HIPBLAS_R_16F: sp.s = scalar.Value<float>(); break;
    default: UNIMPLEMENTED();
  }
  return sp;
}
// Selects the hipBLAS accumulation/compute dtype for a given storage dtype.
// Half inputs accumulate in HIPBLAS_R_16F here; unsupported dtypes abort.
hipblasDatatype_t GetComputeType(DataType data_type) {
  if (data_type == kFloat) { return HIPBLAS_R_32F; }
  if (data_type == kDouble) { return HIPBLAS_R_64F; }
  if (data_type == kFloat16) { return HIPBLAS_R_16F; }
  UNIMPLEMENTED();
  return HIPBLAS_R_32F;  // unreachable fallback to satisfy the return type
}
// Dispatches a (possibly batch-broadcast) matmul C = alpha * op(A) * op(B) + beta * C
// to hipBLAS. Two paths:
//   * exactly one batch dim with a batched output -> one strided-batched GEMM
//     (a stride of 0 broadcasts the operand whose batch count is 1);
//   * anything else -> ForEachMatmul enumerates batch slices and issues one
//     plain GEMM per slice.
void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,
                           BlasTransposeType transpose_b, int64_t num_batch_dims,
                           const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,
                           const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,
                           int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,
                           Scalar beta, void* c) {
  auto* cuda_stream = stream->As<CudaStream>();
  const auto cuda_data_type = GetCudaDataType(data_type);
  const auto compute_type = GetComputeType(data_type);
  const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
  // With HIPBLAS_R_16F compute, hipblasGemm*Ex expects alpha/beta as __half.
  __half h_alpha = 0;
  if (compute_type == HIPBLAS_R_16F) { h_alpha = __float2half(sp_alpha.s); }
  const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
    if (transpose_type == BlasTransposeType::N) {
      return HIPBLAS_OP_N;
    } else if (transpose_type == BlasTransposeType::T) {
      return HIPBLAS_OP_T;
    } else {
      UNIMPLEMENTED();
      return HIPBLAS_OP_N;
    }
  };
  // hipBLAS is column-major while this API is row-major, so we compute
  // C^T = B^T * A^T: operands, transposes, and the m/n dims are all swapped.
  const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b);
  const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a);
  const int cublas_m = n;
  const int cublas_n = m;
  const int cublas_k = k;
  // Leading dimension of the (swapped) first operand follows B's transpose;
  // second operand follows A's.
  int cublas_lda = 0;
  if (transpose_b == BlasTransposeType::N) {
    cublas_lda = n;
  } else if (transpose_b == BlasTransposeType::T) {
    cublas_lda = k;
  } else {
    UNIMPLEMENTED();
  }
  int cublas_ldb = 0;
  if (transpose_a == BlasTransposeType::N) {
    cublas_ldb = k;
  } else if (transpose_a == BlasTransposeType::T) {
    cublas_ldb = m;
  } else {
    UNIMPLEMENTED();
  }
  const int cublas_ldc = n;
  hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
  if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
    // Strided-batched path. Operands already swapped for column-major.
    const void* cublas_a = b;
    const void* cublas_b = a;
    void* cublas_c = c;
    const int64_t a_batch_count = a_batch_dims[0];
    const int64_t b_batch_count = b_batch_dims[0];
    // Either both operands share the batch count, or one of them broadcasts.
    CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);
    CHECK_GT(a_batch_count, 0);
    CHECK_GT(b_batch_count, 0);
    const int batch_count = std::max(a_batch_count, b_batch_count);
    // Stride 0 re-reads the same matrix for every batch (broadcast).
    const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;
    const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;
    const long long int cublas_stride_c = cublas_m * cublas_n;
    const auto sp_beta = GetCublasScalarParameter(beta, compute_type);
    __half h_beta = 0;
    if (compute_type == HIPBLAS_R_16F) {
      h_beta = __float2half(sp_beta.s);
      OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
          cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
          cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b,
          cuda_data_type, cublas_ldb, cublas_stride_b, &h_beta, cublas_c, cuda_data_type,
          cublas_ldc, cublas_stride_c, batch_count, compute_type, algo));
    } else {
      OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
          cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
          cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b,
          cuda_data_type, cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type,
          cublas_ldc, cublas_stride_c, batch_count, compute_type, algo));
    }
  } else {
    // General-broadcast path: one GEMM per (a, b, c) batch slice. `batch_beta`
    // differs per slice because ForEachMatmul may accumulate into c.
    auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {
      const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);
      __half h_beta = 0;
      const void* cublas_a = batch_b;
      const void* cublas_b = batch_a;
      void* cublas_c = batch_c;
      if (compute_type == HIPBLAS_R_16F) {
        h_beta = __float2half(sp_beta.s);
        OF_CUBLAS_CHECK(hipblasGemmEx(
            cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
            cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
            cublas_ldb, &h_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
      } else {
        OF_CUBLAS_CHECK(hipblasGemmEx(
            cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
            cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
            cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
      }
    };
    ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,
                               a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);
  }
}
class
BroadcastMatmulFactoryImpl
:
public
BroadcastMatmulFactory
{
class
BroadcastMatmulFactoryImpl
:
public
BroadcastMatmulFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastMatmulFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastMatmulFactoryImpl
);
BroadcastMatmulFactoryImpl
()
=
default
;
BroadcastMatmulFactoryImpl
()
=
default
;
~
BroadcastMatmulFactoryImpl
()
override
=
default
;
~
BroadcastMatmulFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastMatmul
>
New
(
DataType
data_type
,
BlasTransposeType
transpose_a
,
std
::
unique_ptr
<
BroadcastMatmul
>
New
(
DataType
data_type
,
BlasTransposeType
transpose_a
,
BlasTransposeType
transpose_b
,
BlasTransposeType
transpose_b
,
size_t
max_num_dims
)
override
{
size_t
max_num_dims
)
override
{
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
if
(
max_num_dims
<=
kMaxNumDims
&&
cuda_data_type
.
has_value
())
{
if
(
max_num_dims
<=
kMaxNumDims
&&
cuda_data_type
.
has_value
())
{
return
std
::
make_unique
<
BroadcastMatmulImpl
<
kMaxNumDims
>>
(
data_type
,
transpose_a
,
return
std
::
make_unique
<
BroadcastMatmulImpl
<
kMaxNumDims
>>
(
data_type
,
transpose_a
,
transpose_b
);
transpose_b
);
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastMatmulFactory
,
BroadcastMatmulFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastMatmulFactory
,
BroadcastMatmulFactoryImpl
);
}
// namespace
}
// namespace
}
// namespace internal
}
// namespace internal
}
// namespace broadcast_matmul
}
// namespace broadcast_matmul
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
#endif // WITH_ROCM
#endif // WITH_ROCM
oneflow/core/ep/rocm/primitive/cast.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
// Generic elementwise cast: a plain static_cast between the two types.
template<typename To, typename From, typename = void>
struct CastFunctor {
  __device__ To operator()(From from) const { return static_cast<To>(from); }
};

// half -> To (To != half): route through float, since half does not convert
// directly to every target type on device.
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
  __device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }

  // Vectorized two-element path: loads a half2 in one access and converts
  // both lanes. Requires `from` to point at two contiguous, half2-aligned
  // values (guaranteed by the elementwise launcher's packing).
  __device__ void Apply2(To* to, const half* from) const {
    const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));
    to[0] = static_cast<To>(f2.x);
    to[1] = static_cast<To>(f2.y);
  }
};

// From -> half (From != half): route through float, mirroring the above.
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
  __device__ half operator()(From from) const {
    return static_cast<half>(static_cast<float>(from));
  }

  // Vectorized two-element path: converts both values to float, then stores
  // them as one half2 (round-to-nearest) in a single access.
  __device__ void Apply2(half* to, const From* from) const {
    float2 f2;
    f2.x = static_cast<float>(from[0]);
    f2.y = static_cast<float>(from[1]);
    *reinterpret_cast<half2*>(to) = __float22half2_rn(f2);
  }
};
// #if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
// template<typename To>
// template<typename To>
// struct CastFunctor<To, nv_bfloat16,
// struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> {
// || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const {
// __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from));
// return static_cast<To>(static_cast<float>(from));
// }
// }
// };
// };
// template<typename From>
// template<typename From>
// struct CastFunctor<nv_bfloat16, From,
// struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> {
// || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const {
// __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// }
// }
// };
// };
// #endif // CUDA_VERSION >= 11000
// #endif // CUDA_VERSION >= 11000
template
<
typename
From
,
typename
To
>
template
<
typename
From
,
typename
To
>
class
CastImpl
:
public
Cast
{
class
CastImpl
:
public
Cast
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
CastImpl
);
explicit
CastImpl
()
=
default
;
explicit
CastImpl
()
=
default
;
~
CastImpl
()
override
=
default
;
~
CastImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
from
,
void
*
to
,
size_t
count
)
override
{
void
Launch
(
Stream
*
stream
,
const
void
*
from
,
void
*
to
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
CastFunctor
<
To
,
From
>
,
To
,
From
>
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
CastFunctor
<
To
,
From
>
,
To
,
From
>
(
CastFunctor
<
To
,
From
>
(),
count
,
reinterpret_cast
<
To
*>
(
to
),
CastFunctor
<
To
,
From
>
(),
count
,
reinterpret_cast
<
To
*>
(
to
),
reinterpret_cast
<
const
From
*>
(
from
),
cuda_stream
->
cuda_stream
())));
reinterpret_cast
<
const
From
*>
(
from
),
cuda_stream
->
cuda_stream
())));
}
}
};
};
template
<
typename
From
,
typename
To
>
template
<
typename
From
,
typename
To
>
std
::
unique_ptr
<
Cast
>
NewCast
()
{
std
::
unique_ptr
<
Cast
>
NewCast
()
{
return
std
::
unique_ptr
<
Cast
>
(
new
CastImpl
<
From
,
To
>
());
return
std
::
unique_ptr
<
Cast
>
(
new
CastImpl
<
From
,
To
>
());
}
}
// Element types for which a Cast primitive is instantiated; the factory below
// builds the From x To cartesian product over this sequence.
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
  CUDA_PRIMITIVE_BOOL_TYPE_SEQ       \
  CUDA_PRIMITIVE_CHAR_TYPE_SEQ       \
  CUDA_PRIMITIVE_INT8_TYPE_SEQ       \
  CUDA_PRIMITIVE_UINT8_TYPE_SEQ      \
  CUDA_PRIMITIVE_INT32_TYPE_SEQ      \
  CUDA_PRIMITIVE_UINT32_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT64_TYPE_SEQ      \
  CUDA_PRIMITIVE_UINT64_TYPE_SEQ     \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ      \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ     \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ    \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
class
CastFactoryImpl
:
public
CastFactory
{
class
CastFactoryImpl
:
public
CastFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
CastFactoryImpl
);
CastFactoryImpl
()
=
default
;
CastFactoryImpl
()
=
default
;
~
CastFactoryImpl
()
override
=
default
;
~
CastFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Cast
>
New
(
DataType
from
,
DataType
to
)
override
{
std
::
unique_ptr
<
Cast
>
New
(
DataType
from
,
DataType
to
)
override
{
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{
std
::
make_pair
(
OF_PP_PAIR_SECOND
(
from_pair
),
OF_PP_PAIR_SECOND
(
to_pair
)),
\
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast
<
OF_PP_PAIR_FIRST
(
from_pair
),
OF_PP_PAIR_FIRST
(
to_pair
)
>
},
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static
const
std
::
map
<
std
::
pair
<
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
Cast
>
()
>>
static
const
std
::
map
<
std
::
pair
<
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
Cast
>
()
>>
new_cast_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
new_cast_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_CAST_ENTRY
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
)};
MAKE_NEW_CAST_ENTRY
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
)};
#undef MAKE_NEW_CAST_ENTRY
#undef MAKE_NEW_CAST_ENTRY
const
auto
it
=
new_cast_handle
.
find
(
std
::
make_pair
(
from
,
to
));
const
auto
it
=
new_cast_handle
.
find
(
std
::
make_pair
(
from
,
to
));
if
(
it
!=
new_cast_handle
.
end
())
{
if
(
it
!=
new_cast_handle
.
end
())
{
return
it
->
second
();
return
it
->
second
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Registers the CUDA/ROCm cast factory with the primitive registry.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
// Device kernel for constant padding. Each output element either copies the
// corresponding source element (when its nd-index lies inside the valid
// region) or writes the packed pad value. StorageType is the (possibly
// vectorized) load/store unit, so one element here may cover several logical
// values.
template<size_t num_dims, typename IndexType, typename StorageType>
__global__ void ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params,
                                  StorageType packed_pad_val) {
  const StorageType* src = reinterpret_cast<const StorageType*>(params.src);
  StorageType* dst = reinterpret_cast<StorageType*>(params.dst);
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) {
    params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);
    // Walk the dimensions; bail out as soon as one coordinate falls in the
    // padding band, otherwise accumulate the shifted source coordinate.
    bool in_pad_region = false;
#pragma unroll
    for (int i = 0; i < num_dims; i++) {
      if (dst_index[i] < params.valid_start[i] || dst_index[i] >= params.valid_end[i]) {
        in_pad_region = true;
        break;
      }
      src_index[i] = dst_index[i] - params.valid_start[i];
    }
    dst[linear_index] = in_pad_region
                            ? packed_pad_val
                            : src[params.src_index_helper.NdIndexToOffset(src_index)];
  }
}
// half specialization: Scalar has no direct half accessor, so extract the
// value as float and narrow it.
template<>
half GetValue<half>(Scalar value) {
  const float f = GetValue<float>(value);
  return static_cast<half>(f);
}
// #if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
// template<>
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// }
// #endif // CUDA_VERSION >= 11000
// #endif // CUDA_VERSION >= 11000
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
void
LaunchKernel
(
Stream
*
stream
,
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
stream
->
As
<
CudaStream
>
()
->
LaunchKernelDefaultWaves
(
stream
->
As
<
CudaStream
>
()
->
LaunchKernelDefaultWaves
(
(
ConstantPadKernel
<
num_dims
,
IndexType
,
StorageType
>
),
elem_cnt
,
params
,
packed_pad_val
);
(
ConstantPadKernel
<
num_dims
,
IndexType
,
StorageType
>
),
elem_cnt
,
params
,
packed_pad_val
);
}
}
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
void
LaunchKernel
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
ConstantPadParams
<
num_dims
,
IndexType
>
params
;
ConstantPadParams
<
num_dims
,
IndexType
>
params
;
params
.
dst_index_helper
=
OffsetToIndexCalculator
<
IndexType
,
num_dims
>
(
dst_dims
);
params
.
dst_index_helper
=
OffsetToIndexCalculator
<
IndexType
,
num_dims
>
(
dst_dims
);
params
.
src_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
num_dims
>
(
src_dims
);
params
.
src_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
num_dims
>
(
src_dims
);
params
.
dst
=
dst
;
params
.
dst
=
dst
;
params
.
src
=
src
;
params
.
src
=
src
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
params
.
valid_start
[
i
]
=
padding_before
[
i
];
params
.
valid_start
[
i
]
=
padding_before
[
i
];
params
.
valid_end
[
i
]
=
dst_dims
[
i
]
-
padding_after
[
i
];
params
.
valid_end
[
i
]
=
dst_dims
[
i
]
-
padding_after
[
i
];
}
}
params
.
elem_cnt
=
elem_cnt
;
params
.
elem_cnt
=
elem_cnt
;
LaunchKernel
<
num_dims
,
IndexType
,
StorageType
>
(
stream
,
params
,
packed_pad_val
,
elem_cnt
);
LaunchKernel
<
num_dims
,
IndexType
,
StorageType
>
(
stream
,
params
,
packed_pad_val
,
elem_cnt
);
}
}
template
<
size_t
num_dims
,
typename
StorageType
>
template
<
size_t
num_dims
,
typename
StorageType
>
void
DispatchIndexType
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
void
DispatchIndexType
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
if
(
elem_cnt
<
GetMaxVal
<
int32_t
>
())
{
if
(
elem_cnt
<
GetMaxVal
<
int32_t
>
())
{
LaunchKernel
<
num_dims
,
int32_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
LaunchKernel
<
num_dims
,
int32_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
elem_cnt
);
}
else
{
}
else
{
LaunchKernel
<
num_dims
,
int64_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
LaunchKernel
<
num_dims
,
int64_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
elem_cnt
);
}
}
}
}
template
<
size_t
num_dims
,
typename
T
>
template
<
size_t
num_dims
,
typename
T
>
void
DispatchPackSize
(
Stream
*
stream
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
void
DispatchPackSize
(
Stream
*
stream
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
T
pad_val
)
{
constexpr
int32_t
max_packsize
=
GetMaxPackSize
<
T
>
();
constexpr
int32_t
max_packsize
=
GetMaxPackSize
<
T
>
();
size_t
launch_pack_size
=
GetLaunchPackSize
<
max_packsize
>
(
num_dims
,
dst
,
dst_dims
,
src
,
src_dims
,
size_t
launch_pack_size
=
GetLaunchPackSize
<
max_packsize
>
(
num_dims
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
);
padding_before
,
padding_after
);
dst_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
dst_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
src_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
src_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_before
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_before
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_after
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_after
[
num_dims
-
1
]
/=
launch_pack_size
;
size_t
elem_cnt
=
1
;
size_t
elem_cnt
=
1
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
elem_cnt
*=
dst_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
elem_cnt
*=
dst_dims
[
i
];
}
if
(
launch_pack_size
==
1
)
{
if
(
launch_pack_size
==
1
)
{
Pack
<
T
,
1
>
packed_pad_val
(
pad_val
);
Pack
<
T
,
1
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
1
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
DispatchIndexType
<
num_dims
,
PackType
<
T
,
1
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
2
)
{
}
else
if
(
launch_pack_size
==
2
)
{
Pack
<
T
,
2
>
packed_pad_val
(
pad_val
);
Pack
<
T
,
2
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
2
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
DispatchIndexType
<
num_dims
,
PackType
<
T
,
2
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
4
)
{
}
else
if
(
launch_pack_size
==
4
)
{
Pack
<
T
,
4
>
packed_pad_val
(
pad_val
);
Pack
<
T
,
4
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
4
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
DispatchIndexType
<
num_dims
,
PackType
<
T
,
4
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
8
)
{
}
else
if
(
launch_pack_size
==
8
)
{
Pack
<
T
,
8
>
packed_pad_val
(
pad_val
);
Pack
<
T
,
8
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
8
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
DispatchIndexType
<
num_dims
,
PackType
<
T
,
8
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
16
)
{
}
else
if
(
launch_pack_size
==
16
)
{
Pack
<
T
,
16
>
packed_pad_val
(
pad_val
);
Pack
<
T
,
16
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
16
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
DispatchIndexType
<
num_dims
,
PackType
<
T
,
16
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
packed_pad_val
.
storage
,
elem_cnt
);
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
num_dims
,
void
*
dst
,
int64_t
*
dst_dims
,
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
num_dims
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
int64_t
*
padding_after
,
T
pad_val
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
void
*
/*dst*/
,
int64_t
*
/*dst_dims*/
,
const
void
*
/*src*/
,
void
(
*
func
)(
Stream
*
/*stream*/
,
void
*
/*dst*/
,
int64_t
*
/*dst_dims*/
,
const
void
*
/*src*/
,
int64_t
*
/*src_dims*/
,
int64_t
*
/*padding_before*/
,
int64_t
*
/*padding_after*/
,
T
)
=
int64_t
*
/*src_dims*/
,
int64_t
*
/*padding_before*/
,
int64_t
*
/*padding_after*/
,
T
)
=
nullptr
;
nullptr
;
if
(
num_dims
==
1
)
{
if
(
num_dims
==
1
)
{
func
=
DispatchPackSize
<
1
,
T
>
;
func
=
DispatchPackSize
<
1
,
T
>
;
}
else
if
(
num_dims
==
2
)
{
}
else
if
(
num_dims
==
2
)
{
func
=
DispatchPackSize
<
2
,
T
>
;
func
=
DispatchPackSize
<
2
,
T
>
;
}
else
if
(
num_dims
==
3
)
{
}
else
if
(
num_dims
==
3
)
{
func
=
DispatchPackSize
<
3
,
T
>
;
func
=
DispatchPackSize
<
3
,
T
>
;
}
else
if
(
num_dims
==
4
)
{
}
else
if
(
num_dims
==
4
)
{
func
=
DispatchPackSize
<
4
,
T
>
;
func
=
DispatchPackSize
<
4
,
T
>
;
}
else
if
(
num_dims
==
5
)
{
}
else
if
(
num_dims
==
5
)
{
func
=
DispatchPackSize
<
5
,
T
>
;
func
=
DispatchPackSize
<
5
,
T
>
;
}
else
if
(
num_dims
==
6
)
{
}
else
if
(
num_dims
==
6
)
{
func
=
DispatchPackSize
<
6
,
T
>
;
func
=
DispatchPackSize
<
6
,
T
>
;
}
else
if
(
num_dims
==
7
)
{
}
else
if
(
num_dims
==
7
)
{
func
=
DispatchPackSize
<
7
,
T
>
;
func
=
DispatchPackSize
<
7
,
T
>
;
}
else
if
(
num_dims
==
8
)
{
}
else
if
(
num_dims
==
8
)
{
func
=
DispatchPackSize
<
8
,
T
>
;
func
=
DispatchPackSize
<
8
,
T
>
;
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
func
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
pad_val
);
func
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
pad_val
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
SimplifyThenLaunch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
void
SimplifyThenLaunch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
T
pad_val
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
T
pad_val
,
void
*
dst
)
{
void
*
dst
)
{
CHECK_LE
(
num_dims
,
kMaxNumDims
);
CHECK_LE
(
num_dims
,
kMaxNumDims
);
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_src_dims
[
kMaxNumDims
];
int64_t
simplified_src_dims
[
kMaxNumDims
];
int64_t
simplified_padding_before
[
kMaxNumDims
];
int64_t
simplified_padding_before
[
kMaxNumDims
];
int64_t
simplified_padding_after
[
kMaxNumDims
];
int64_t
simplified_padding_after
[
kMaxNumDims
];
size_t
simplified_num_dims
=
1
;
size_t
simplified_num_dims
=
1
;
SimplifyPadDims
(
num_dims
,
src_dims
,
padding_before
,
padding_after
,
&
simplified_num_dims
,
SimplifyPadDims
(
num_dims
,
src_dims
,
padding_before
,
padding_after
,
&
simplified_num_dims
,
simplified_dst_dims
,
simplified_src_dims
,
simplified_padding_before
,
simplified_dst_dims
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
);
simplified_padding_after
);
LaunchWithSimplified
<
T
>
(
stream
,
simplified_num_dims
,
dst
,
simplified_dst_dims
,
src
,
LaunchWithSimplified
<
T
>
(
stream
,
simplified_num_dims
,
dst
,
simplified_dst_dims
,
src
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
,
pad_val
);
pad_val
);
}
}
template
<
typename
T
>
template
<
typename
T
>
class
ConstantPadImpl
:
public
ConstantPad
{
class
ConstantPadImpl
:
public
ConstantPad
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadImpl
);
ConstantPadImpl
()
=
default
;
ConstantPadImpl
()
=
default
;
~
ConstantPadImpl
()
override
=
default
;
~
ConstantPadImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
void
Launch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
Scalar
pad_val
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
Scalar
pad_val
,
void
*
dst
)
override
{
void
*
dst
)
override
{
SimplifyThenLaunch
<
T
>
(
stream
,
num_dims
,
src_dims
,
src
,
padding_before
,
padding_after
,
SimplifyThenLaunch
<
T
>
(
stream
,
num_dims
,
src_dims
,
src
,
padding_before
,
padding_after
,
GetValue
<
T
>
(
pad_val
),
dst
);
GetValue
<
T
>
(
pad_val
),
dst
);
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
std
::
unique_ptr
<
ConstantPad
>
NewConstantPad
()
{
std
::
unique_ptr
<
ConstantPad
>
NewConstantPad
()
{
return
std
::
unique_ptr
<
ConstantPad
>
(
new
ConstantPadImpl
<
T
>
());
return
std
::
unique_ptr
<
ConstantPad
>
(
new
ConstantPadImpl
<
T
>
());
}
}
class
ConstantPadFactoryImpl
:
public
ConstantPadFactory
{
class
ConstantPadFactoryImpl
:
public
ConstantPadFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadFactoryImpl
);
ConstantPadFactoryImpl
()
=
default
;
ConstantPadFactoryImpl
()
=
default
;
~
ConstantPadFactoryImpl
()
override
=
default
;
~
ConstantPadFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
ConstantPad
>
New
(
DataType
data_type
)
override
{
std
::
unique_ptr
<
ConstantPad
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
ConstantPad
>
()
>>
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
ConstantPad
>
()
>>
new_constant_pad_handle
{
new_constant_pad_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_CONSTANT_PAD_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_CONSTANT_PAD_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
const
auto
it
=
new_constant_pad_handle
.
find
(
data_type
);
const
auto
it
=
new_constant_pad_handle
.
find
(
data_type
);
if
(
it
!=
new_constant_pad_handle
.
end
())
{
if
(
it
!=
new_constant_pad_handle
.
end
())
{
return
it
->
second
();
return
it
->
second
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Registers the CUDA/ROCm constant-pad factory with the primitive registry.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
// Device kernel for an n-dimensional sub-region copy. Each loop iteration
// maps one linear offset within the copied extent to nd-indices, shifts them
// by the source/destination start positions, and moves one storage unit of
// `movement_size` bytes.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {
  // Opaque, correctly-aligned unit of movement_size bytes.
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);
  IndexType copy_index[num_dims];
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.copy_index_helper.OffsetToNdIndex(i, copy_index);
#pragma unroll
    for (size_t j = 0; j < num_dims; ++j) {
      src_index[j] = params.src_pos[j] + copy_index[j];
      dst_index[j] = params.dst_pos[j] + copy_index[j];
    }
    dst[params.dst_index_helper.NdIndexToOffset(dst_index)] =
        src[params.src_index_helper.NdIndexToOffset(src_index)];
  }
}
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
void
LaunchKernel
(
Stream
*
stream
,
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
CopyNdKernel
<
num_dims
,
movement_size
,
IndexType
>
CopyNdKernel
<
num_dims
,
movement_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
}
}
class
CopyNdImpl
:
public
CopyNd
{
class
CopyNdImpl
:
public
CopyNd
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdImpl
);
CopyNdImpl
()
=
default
;
CopyNdImpl
()
=
default
;
~
CopyNdImpl
()
override
=
default
;
~
CopyNdImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
void
*
dst
,
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
int64_t
*
dst_pos
,
const
void
*
src
,
const
int64_t
*
dst_dims
,
const
int64_t
*
dst_pos
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
src_pos
,
const
int64_t
*
src_dims
,
const
int64_t
*
src_pos
,
const
int64_t
*
extent
)
const
override
{
const
int64_t
*
extent
)
const
override
{
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
dst
,
dst_dims
,
dst_pos
,
src
,
src_dims
,
src_pos
,
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
dst
,
dst_dims
,
dst_pos
,
src
,
src_dims
,
src_pos
,
extent
);
extent
);
}
}
};
};
class
CopyNdFactoryImpl
:
public
CopyNdFactory
{
class
CopyNdFactoryImpl
:
public
CopyNdFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdFactoryImpl
);
CopyNdFactoryImpl
()
=
default
;
CopyNdFactoryImpl
()
=
default
;
~
CopyNdFactoryImpl
()
override
=
default
;
~
CopyNdFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
CopyNd
>
New
(
size_t
max_num_dims
)
override
{
std
::
unique_ptr
<
CopyNd
>
New
(
size_t
max_num_dims
)
override
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
return
std
::
unique_ptr
<
CopyNd
>
(
new
CopyNdImpl
());
return
std
::
unique_ptr
<
CopyNd
>
(
new
CopyNdImpl
());
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Registers the CUDA/ROCm copy-nd factory with the primitive registry.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
class
ElementwiseUnaryImpl
:
public
ElementwiseUnary
{
class
ElementwiseUnaryImpl
:
public
ElementwiseUnary
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
ElementwiseUnaryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
ElementwiseUnaryImpl
);
ElementwiseUnaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
ElementwiseUnaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
~
ElementwiseUnaryImpl
()
override
=
default
;
~
ElementwiseUnaryImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
src
,
void
*
dst
,
size_t
count
)
override
{
void
Launch
(
Stream
*
stream
,
const
void
*
src
,
void
*
dst
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
functor
=
UnaryFunctor
<
DeviceType
::
kCUDA
,
unary_op
,
Dst
,
Src
>
(
attr0
,
attr1
);
auto
functor
=
UnaryFunctor
<
DeviceType
::
kCUDA
,
unary_op
,
Dst
,
Src
>
(
attr0
,
attr1
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
decltype
(
functor
),
Dst
,
Src
>
(
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
decltype
(
functor
),
Dst
,
Src
>
(
functor
,
count
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src
),
functor
,
count
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src
),
cuda_stream
->
cuda_stream
())));
cuda_stream
->
cuda_stream
())));
}
}
protected:
protected:
Scalar
attr0
,
attr1
;
Scalar
attr0
,
attr1
;
};
};
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
std
::
unique_ptr
<
ElementwiseUnary
>
NewElementwiseUnary
(
Scalar
attr0
,
Scalar
attr1
)
{
std
::
unique_ptr
<
ElementwiseUnary
>
NewElementwiseUnary
(
Scalar
attr0
,
Scalar
attr1
)
{
return
std
::
unique_ptr
<
ElementwiseUnary
>
(
return
std
::
unique_ptr
<
ElementwiseUnary
>
(
new
ElementwiseUnaryImpl
<
unary_op
,
Src
,
Dst
>
(
attr0
,
attr1
));
new
ElementwiseUnaryImpl
<
unary_op
,
Src
,
Dst
>
(
attr0
,
attr1
));
}
}
class
ElementwiseUnaryFactoryImpl
:
public
ElementwiseUnaryFactory
{
class
ElementwiseUnaryFactoryImpl
:
public
ElementwiseUnaryFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
ElementwiseUnaryFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
ElementwiseUnaryFactoryImpl
);
ElementwiseUnaryFactoryImpl
()
=
default
;
ElementwiseUnaryFactoryImpl
()
=
default
;
~
ElementwiseUnaryFactoryImpl
()
override
=
default
;
~
ElementwiseUnaryFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
DataType
dst_dtype
)
override
{
DataType
dst_dtype
)
override
{
return
New
(
unary_op
,
src_type
,
dst_dtype
,
Scalar
(),
Scalar
());
return
New
(
unary_op
,
src_type
,
dst_dtype
,
Scalar
(),
Scalar
());
}
}
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
DataType
dst_dtype
,
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
DataType
dst_dtype
,
Scalar
attr0
)
override
{
Scalar
attr0
)
override
{
return
New
(
unary_op
,
src_type
,
dst_dtype
,
attr0
,
Scalar
());
return
New
(
unary_op
,
src_type
,
dst_dtype
,
attr0
,
Scalar
());
}
}
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
DataType
dst_dtype
,
std
::
unique_ptr
<
ElementwiseUnary
>
New
(
UnaryOp
unary_op
,
DataType
src_type
,
DataType
dst_dtype
,
Scalar
attr0
,
Scalar
attr1
)
override
{
Scalar
attr0
,
Scalar
attr1
)
override
{
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
{
std
::
make_tuple
(
unary_op
,
OF_PP_PAIR_SECOND
(
dtype_pair
),
OF_PP_PAIR_SECOND
(
dtype_pair
)),
\
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
NewElementwiseUnary
<
unary_op
,
OF_PP_PAIR_FIRST
(
dtype_pair
),
OF_PP_PAIR_FIRST
(
dtype_pair
)
>
},
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \
{
std
::
make_tuple
(
unary_op
,
OF_PP_PAIR_SECOND
(
src_type_pair
),
OF_PP_PAIR_SECOND
(
dst_dtype_pair
)),
\
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \
NewElementwiseUnary
<
unary_op
,
OF_PP_PAIR_FIRST
(
src_type_pair
),
\
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair), \
OF_PP_PAIR_FIRST
(
dst_dtype_pair
)
>
},
OF_PP_PAIR_FIRST(dst_dtype_pair)>},
static
const
std
::
map
<
std
::
tuple
<
UnaryOp
,
DataType
,
DataType
>
,
static
const
std
::
map
<
std
::
tuple
<
UnaryOp
,
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
ElementwiseUnary
>
(
Scalar
,
Scalar
)
>>
std
::
function
<
std
::
unique_ptr
<
ElementwiseUnary
>
(
Scalar
,
Scalar
)
>>
new_elementwise_unary_handle
{
new_elementwise_unary_handle
{
// For All Type OP
// For All Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
UNARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
// For Float Type OP
// For Float Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_FLOATING_MATH_OP_SEQ
,
UNARY_FLOATING_MATH_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)
// For Utils OP
// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_UTILS_OP_SEQ
,
UTIL_OPS_DATA_TYPE_SEQ
,
UNARY_UTILS_OP_SEQ
,
UTIL_OPS_DATA_TYPE_SEQ
,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
// For Logical OP
// For Logical OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
UNARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)};
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
const
auto
it
=
const
auto
it
=
new_elementwise_unary_handle
.
find
(
std
::
make_tuple
(
unary_op
,
src_type
,
dst_dtype
));
new_elementwise_unary_handle
.
find
(
std
::
make_tuple
(
unary_op
,
src_type
,
dst_dtype
));
if
(
it
!=
new_elementwise_unary_handle
.
end
())
{
if
(
it
!=
new_elementwise_unary_handle
.
end
())
{
return
it
->
second
(
attr0
,
attr1
);
return
it
->
second
(
attr0
,
attr1
);
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
ElementwiseUnaryFactory
,
ElementwiseUnaryFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
ElementwiseUnaryFactory
,
ElementwiseUnaryFactoryImpl
);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/fill.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
template
<
size_t
size
>
template
<
size_t
size
>
using
Storage
=
typename
std
::
aligned_storage
<
size
,
size
>::
type
;
using
Storage
=
typename
std
::
aligned_storage
<
size
,
size
>::
type
;
template
<
typename
T
,
size_t
pack
>
template
<
typename
T
,
size_t
pack
>
union
Pack
{
union
Pack
{
static
constexpr
size_t
size
=
sizeof
(
T
)
*
pack
;
static
constexpr
size_t
size
=
sizeof
(
T
)
*
pack
;
explicit
__device__
__host__
Pack
(
T
value
)
{
explicit
__device__
__host__
Pack
(
T
value
)
{
static_assert
(
sizeof
(
Pack
)
==
size
,
""
);
static_assert
(
sizeof
(
Pack
)
==
size
,
""
);
static_assert
(
alignof
(
Pack
)
==
size
,
""
);
static_assert
(
alignof
(
Pack
)
==
size
,
""
);
#pragma unroll
#pragma unroll
for
(
size_t
i
=
0
;
i
<
pack
;
++
i
)
{
elem
[
i
]
=
value
;
}
for
(
size_t
i
=
0
;
i
<
pack
;
++
i
)
{
elem
[
i
]
=
value
;
}
}
}
T
elem
[
pack
];
T
elem
[
pack
];
Storage
<
size
>
storage
;
Storage
<
size
>
storage
;
};
};
// Device kernel: sets count elements of dst to value. The bulk of the buffer
// is written `pack` elements at a time through the Pack storage word; the
// remaining count % pack elements are written one by one.
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
  const size_t pack_count = count / pack;
  Pack<T, pack> pack_value(value);
  auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);
  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }
  // Scalar tail after the last full pack.
  T* tail_dst = dst + pack_count * pack;
  const size_t tail_count = count - pack_count * pack;
  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}
template
<
typename
T
>
template
<
typename
T
>
T
GetValue
(
Scalar
value
)
{
T
GetValue
(
Scalar
value
)
{
return
value
.
Value
<
T
>
();
return
value
.
Value
<
T
>
();
}
}
template
<
>
template
<
>
half
GetValue
<
half
>
(
Scalar
value
)
{
half
GetValue
<
half
>
(
Scalar
value
)
{
return
static_cast
<
half
>
(
GetValue
<
float
>
(
value
));
return
static_cast
<
half
>
(
GetValue
<
float
>
(
value
));
}
}
// #if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
// template<>
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// }
// #endif // CUDA_VERSION >= 11000
// #endif // CUDA_VERSION >= 11000
template
<
typename
T
,
size_t
pack
>
template
<
typename
T
,
size_t
pack
>
typename
std
::
enable_if
<
(
pack
!=
0
),
void
>::
type
LaunchPackFill
(
hipStream_t
stream
,
T
*
dst
,
typename
std
::
enable_if
<
(
pack
!=
0
),
void
>::
type
LaunchPackFill
(
hipStream_t
stream
,
T
*
dst
,
T
value
,
size_t
count
)
{
T
value
,
size_t
count
)
{
FillGpu
<
T
,
pack
>
FillGpu
<
T
,
pack
>
<<<
BlocksNum4ThreadsNum
(
count
),
kCudaThreadsNumPerBlock
,
0
,
stream
>>>
(
dst
,
value
,
count
);
<<<
BlocksNum4ThreadsNum
(
count
),
kCudaThreadsNumPerBlock
,
0
,
stream
>>>
(
dst
,
value
,
count
);
}
}
template
<
typename
T
,
size_t
pack
>
template
<
typename
T
,
size_t
pack
>
typename
std
::
enable_if
<
(
pack
==
0
),
void
>::
type
LaunchPackFill
(
hipStream_t
stream
,
T
*
dst
,
typename
std
::
enable_if
<
(
pack
==
0
),
void
>::
type
LaunchPackFill
(
hipStream_t
stream
,
T
*
dst
,
T
value
,
size_t
count
)
{
T
value
,
size_t
count
)
{
LOG
(
FATAL
)
<<
"wrong alignment"
;
LOG
(
FATAL
)
<<
"wrong alignment"
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
LaunchFill
(
hipStream_t
stream
,
T
*
dst
,
T
value
,
size_t
count
)
{
void
LaunchFill
(
hipStream_t
stream
,
T
*
dst
,
T
value
,
size_t
count
)
{
auto
uintptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
dst
);
auto
uintptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
dst
);
if
(
uintptr
%
16
==
0
)
{
if
(
uintptr
%
16
==
0
)
{
LaunchPackFill
<
T
,
16
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
LaunchPackFill
<
T
,
16
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
}
else
if
(
uintptr
%
8
==
0
)
{
}
else
if
(
uintptr
%
8
==
0
)
{
LaunchPackFill
<
T
,
8
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
LaunchPackFill
<
T
,
8
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
}
else
if
(
uintptr
%
4
==
0
)
{
}
else
if
(
uintptr
%
4
==
0
)
{
LaunchPackFill
<
T
,
4
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
LaunchPackFill
<
T
,
4
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
}
else
if
(
uintptr
%
2
==
0
)
{
}
else
if
(
uintptr
%
2
==
0
)
{
LaunchPackFill
<
T
,
2
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
LaunchPackFill
<
T
,
2
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
}
else
{
}
else
{
LaunchPackFill
<
T
,
1
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
LaunchPackFill
<
T
,
1
/
sizeof
(
T
)
>
(
stream
,
dst
,
value
,
count
);
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
class
FillImpl
:
public
Fill
{
class
FillImpl
:
public
Fill
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
FillImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
FillImpl
);
FillImpl
()
=
default
;
FillImpl
()
=
default
;
~
FillImpl
()
override
=
default
;
~
FillImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
void
*
dst
,
Scalar
value
,
size_t
count
)
override
{
void
Launch
(
Stream
*
stream
,
void
*
dst
,
Scalar
value
,
size_t
count
)
override
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
LaunchFill
<
T
>
(
cuda_stream
,
reinterpret_cast
<
T
*>
(
dst
),
GetValue
<
T
>
(
value
),
count
);
LaunchFill
<
T
>
(
cuda_stream
,
reinterpret_cast
<
T
*>
(
dst
),
GetValue
<
T
>
(
value
),
count
);
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
std
::
unique_ptr
<
Fill
>
NewFill
()
{
std
::
unique_ptr
<
Fill
>
NewFill
()
{
return
std
::
unique_ptr
<
Fill
>
(
new
FillImpl
<
T
>
());
return
std
::
unique_ptr
<
Fill
>
(
new
FillImpl
<
T
>
());
}
}
class
FillFactoryImpl
:
public
FillFactory
{
class
FillFactoryImpl
:
public
FillFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
FillFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
FillFactoryImpl
);
FillFactoryImpl
()
=
default
;
FillFactoryImpl
()
=
default
;
~
FillFactoryImpl
()
override
=
default
;
~
FillFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Fill
>
New
(
DataType
data_type
)
override
{
std
::
unique_ptr
<
Fill
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
Fill
>
()
>>
new_fill_handle
{
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
Fill
>
()
>>
new_fill_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_FILL_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_FILL_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
#undef MAKE_NEW_FILL_ENTRY
#undef MAKE_NEW_FILL_ENTRY
const
auto
it
=
new_fill_handle
.
find
(
data_type
);
const
auto
it
=
new_fill_handle
.
find
(
data_type
);
if
(
it
!=
new_fill_handle
.
end
())
{
if
(
it
!=
new_fill_handle
.
end
())
{
return
it
->
second
();
return
it
->
second
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
FillFactory
,
FillFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
FillFactory
,
FillFactoryImpl
);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/memcpy.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#ifdef WITH_ROCM
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
class
MemcpyImpl
:
public
Memcpy
{
class
MemcpyImpl
:
public
Memcpy
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyImpl
);
MemcpyImpl
()
=
default
;
MemcpyImpl
()
=
default
;
~
MemcpyImpl
()
override
=
default
;
~
MemcpyImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
void
*
dst
,
const
void
*
src
,
size_t
count
)
override
{
void
Launch
(
Stream
*
stream
,
void
*
dst
,
const
void
*
src
,
size_t
count
)
override
{
if
(
dst
==
src
)
{
return
;
}
if
(
dst
==
src
)
{
return
;
}
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
(
hipMemcpyAsync
(
dst
,
src
,
count
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemcpyAsync
(
dst
,
src
,
count
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
}
}
};
};
class
MemcpyFactoryImpl
:
public
MemcpyFactory
{
class
MemcpyFactoryImpl
:
public
MemcpyFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyFactoryImpl
);
MemcpyFactoryImpl
()
=
default
;
MemcpyFactoryImpl
()
=
default
;
~
MemcpyFactoryImpl
()
override
=
default
;
~
MemcpyFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Memcpy
>
New
(
MemcpyKind
kind
)
override
{
std
::
unique_ptr
<
Memcpy
>
New
(
MemcpyKind
kind
)
override
{
return
std
::
unique_ptr
<
Memcpy
>
(
new
MemcpyImpl
());
return
std
::
unique_ptr
<
Memcpy
>
(
new
MemcpyImpl
());
}
}
};
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemcpyFactory
,
MemcpyFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemcpyFactory
,
MemcpyFactoryImpl
);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
#endif
#endif
oneflow/core/ep/rocm/primitive/memset.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#ifdef WITH_ROCM
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
class
MemsetImpl
:
public
Memset
{
class
MemsetImpl
:
public
Memset
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemsetImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
MemsetImpl
);
MemsetImpl
()
=
default
;
MemsetImpl
()
=
default
;
~
MemsetImpl
()
override
=
default
;
~
MemsetImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
void
*
ptr
,
int
value
,
size_t
count
)
override
{
void
Launch
(
Stream
*
stream
,
void
*
ptr
,
int
value
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
(
hipMemsetAsync
(
ptr
,
value
,
count
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemsetAsync
(
ptr
,
value
,
count
,
cuda_stream
->
cuda_stream
()));
}
}
};
};
class
MemsetFactoryImpl
:
public
MemsetFactory
{
class
MemsetFactoryImpl
:
public
MemsetFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemsetFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
MemsetFactoryImpl
);
MemsetFactoryImpl
()
=
default
;
MemsetFactoryImpl
()
=
default
;
~
MemsetFactoryImpl
()
override
=
default
;
~
MemsetFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Memset
>
New
()
override
{
return
std
::
unique_ptr
<
Memset
>
(
new
MemsetImpl
());
}
std
::
unique_ptr
<
Memset
>
New
()
override
{
return
std
::
unique_ptr
<
Memset
>
(
new
MemsetImpl
());
}
};
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemsetFactory
,
MemsetFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemsetFactory
,
MemsetFactoryImpl
);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
#endif
#endif
oneflow/core/ep/rocm/primitive/permute.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
permute
{
namespace
permute
{
namespace
internal
{
namespace
internal
{
namespace
{
namespace
{
// Tile edge lengths for the shared-memory batch transpose, keyed by element
// movement size, plus the number of rows each thread block strides per pass.
constexpr int32_t kMov4TileSize = 32;
constexpr int32_t kMov2TileSize = 64;
constexpr int32_t kBlockRows = 8;
// Generic permute kernel: for each destination offset, decomposes it into a
// dst nd-index, maps it through the permutation to a src nd-index, and copies
// one element of `movement_size` bytes. T is an opaque fixed-size storage
// type so a single kernel handles every element width.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);

  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
    for (size_t dim = 0; dim < num_dims; ++dim) {
      // dst dim `dim` reads from src dim `permutation[dim]`.
      src_index[params.permutation[dim]] = dst_index[dim];
    }
    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
    dst[i] = src[src_offset];
  }
}
// (B, X, Y) -> (B, Y, X)
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// Tiled batch transpose: each block loads a tile_size x tile_size tile of one
// batch's matrix into shared memory row-wise, then writes it back transposed.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                     IndexType cols, IndexType num_tile_rows,
                                     IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  __shared__ T tile[tile_size][tile_size + 1];  // To avoid bank conflict.

  const T* src = reinterpret_cast<const T*>(src_ptr);
  T* dst = reinterpret_cast<T*>(dst_ptr);

  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    const IndexType tile_index =
        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the
                                           // flatten index of tile in a batch.
    const IndexType tile_row_index =
        tile_index / num_tile_cols;  // the row index of tile in a batch.
    const IndexType tile_col_index =
        tile_index
        - tile_row_index
              * num_tile_cols;  // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType offset = batch_index * src_rows * src_cols;
    {
      // Phase 1: cooperative load of one tile from src into shared memory.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
        }
      }
    }
    __syncthreads();
    {
      // Phase 2: write the tile back transposed (note swapped tile indices).
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
        }
      }
    }
    // Keep iterations of the grid-wide loop from overwriting a tile that is
    // still being read.
    __syncthreads();
  }
}
/*
Here is a MovementSize=2 version of Batch Transpose.
When H and W are both divisible by 2, we can read the data with movement_size=4 and write it
back with movement_size=4.
*/
// Batched tiled transpose specialized for 2-byte elements: each thread moves a pair of
// adjacent 2-byte values as one 4-byte word, staging them in shared memory.
// Grid-wide loop over tiles: block i handles tiles i, i+gridDim.x, ... up to block_nums.
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                              IndexType cols, IndexType num_tile_rows,
                                              IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  // The destination matrix is the transpose of the source.
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  static_assert(tile_size % 2 == 0, "");

  // Opaque 2-byte and 4-byte storage types used to move raw element pairs.
  using T_MOV2 = typename std::aligned_storage<2, 2>::type;
  using T_MOV4 = typename std::aligned_storage<4, 4>::type;

  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);

  // Use a union so the same shared buffer can be written with 4-byte stores and
  // read back with 2-byte loads. The +2 / +1 padding avoids bank conflicts.
  __shared__ union {
    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half [64][66]
    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]
  } tile_mem;

  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    // equal to i % (num_tile_rows*num_tile_cols): the flattened tile index within a batch.
    const IndexType tile_index = i - batch_index * batch_num_tile;
    // the row index of tile in a batch.
    const IndexType tile_row_index = tile_index / num_tile_cols;
    // equal to tile_index % num_tile_cols: the col index of tile in a batch.
    const IndexType tile_col_index = tile_index - tile_row_index * num_tile_cols;
    // Element offset of this batch's matrix.
    const IndexType offset = batch_index * src_rows * src_cols;

    {
      // Load phase: each thread copies one 4-byte word (two elements) per row it visits.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile_mem.tile_m4[row_in_tile][col_in_tile] =
              src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
        }
      }
    }
    __syncthreads();
    {
      // Store phase: gather two transposed 2-byte values from shared memory and
      // emit them as a single 4-byte word.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        union {
          T_MOV4 m4;
          T_MOV2 m2[2];
        } tmp_storage;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
          tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
          dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
        }
      }
    }
    __syncthreads();
  }
}
template
<
size_t
num_dims
,
size_t
movement_size
,
size_t
tile_size
,
typename
IndexType
>
template
<
size_t
num_dims
,
size_t
movement_size
,
size_t
tile_size
,
typename
IndexType
>
void
LaunchBatchTransposeKernel
(
hipStream_t
&
cuda_stream
,
void
LaunchBatchTransposeKernel
(
hipStream_t
&
cuda_stream
,
const
PermuteKernelParams
<
num_dims
,
IndexType
>&
params
,
const
PermuteKernelParams
<
num_dims
,
IndexType
>&
params
,
const
IndexType
&
num_batches
,
const
IndexType
&
rows
,
const
IndexType
&
num_batches
,
const
IndexType
&
rows
,
const
IndexType
&
cols
)
{
const
IndexType
&
cols
)
{
IndexType
num_tile_rows
=
(
rows
+
tile_size
-
1
)
/
tile_size
;
IndexType
num_tile_rows
=
(
rows
+
tile_size
-
1
)
/
tile_size
;
IndexType
num_tile_cols
=
(
cols
+
tile_size
-
1
)
/
tile_size
;
IndexType
num_tile_cols
=
(
cols
+
tile_size
-
1
)
/
tile_size
;
const
int32_t
block_nums
=
num_batches
*
num_tile_rows
*
num_tile_cols
;
const
int32_t
block_nums
=
num_batches
*
num_tile_rows
*
num_tile_cols
;
int32_t
launched_block_nums
=
std
::
min
(
block_nums
,
kCudaMaxBlocksNum
);
int32_t
launched_block_nums
=
std
::
min
(
block_nums
,
kCudaMaxBlocksNum
);
if
(
tile_size
==
kMov2TileSize
)
{
if
(
tile_size
==
kMov2TileSize
)
{
const
int32_t
half2_thread
=
tile_size
/
2
;
// cause each thread process two half elements.
const
int32_t
half2_thread
=
tile_size
/
2
;
// cause each thread process two half elements.
BatchTransposeMovement2Kernel
<
num_dims
,
kMov2TileSize
,
IndexType
>
BatchTransposeMovement2Kernel
<
num_dims
,
kMov2TileSize
,
IndexType
>
<<<
launched_block_nums
,
dim3
(
half2_thread
,
kBlockRows
),
0
,
cuda_stream
>>>
(
<<<
launched_block_nums
,
dim3
(
half2_thread
,
kBlockRows
),
0
,
cuda_stream
>>>
(
params
.
src
,
params
.
dst
,
rows
,
cols
,
num_tile_rows
,
num_tile_cols
,
params
.
src
,
params
.
dst
,
rows
,
cols
,
num_tile_rows
,
num_tile_cols
,
block_nums
);
// Set threads num as 32x8 cause each threads
block_nums
);
// Set threads num as 32x8 cause each threads
// process 4 elements to 64x66 half share memory.
// process 4 elements to 64x66 half share memory.
}
else
{
}
else
{
BatchTransposeKernel
<
num_dims
,
movement_size
,
tile_size
,
IndexType
>
BatchTransposeKernel
<
num_dims
,
movement_size
,
tile_size
,
IndexType
>
<<<
launched_block_nums
,
dim3
(
tile_size
,
kBlockRows
),
0
,
cuda_stream
>>>
(
<<<
launched_block_nums
,
dim3
(
tile_size
,
kBlockRows
),
0
,
cuda_stream
>>>
(
params
.
src
,
params
.
dst
,
rows
,
cols
,
num_tile_rows
,
num_tile_cols
,
block_nums
);
params
.
src
,
params
.
dst
,
rows
,
cols
,
num_tile_rows
,
num_tile_cols
,
block_nums
);
}
}
}
}
// Returns true iff both extents are at least tile_size, i.e. at least one full
// tile fits in each dimension.
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
  return rows >= tile_size && cols >= tile_size;
}
template
<
size_t
num_dims
,
size_t
tile_size
,
typename
IndexType
>
template
<
size_t
num_dims
,
size_t
tile_size
,
typename
IndexType
>
bool
CheckLaunchBatchTranspose
(
const
int
*
permutation
,
const
IndexType
&
num_batches
,
bool
CheckLaunchBatchTranspose
(
const
int
*
permutation
,
const
IndexType
&
num_batches
,
const
IndexType
&
rows
,
const
IndexType
&
cols
)
{
const
IndexType
&
rows
,
const
IndexType
&
cols
)
{
if
(
CheckIfGreaterEqualThanTileSize
<
tile_size
,
IndexType
>
(
rows
,
cols
))
{
if
(
CheckIfGreaterEqualThanTileSize
<
tile_size
,
IndexType
>
(
rows
,
cols
))
{
if
(
num_batches
==
1
&&
permutation
[
1
]
==
0
&&
permutation
[
0
]
==
1
)
{
if
(
num_batches
==
1
&&
permutation
[
1
]
==
0
&&
permutation
[
0
]
==
1
)
{
// 2d tensor case: (0, 1) -> (1, 0)
// 2d tensor case: (0, 1) -> (1, 0)
return
true
;
return
true
;
}
else
if
(
num_dims
==
3
&&
permutation
[
2
]
==
1
&&
permutation
[
1
]
==
2
)
{
}
else
if
(
num_dims
==
3
&&
permutation
[
2
]
==
1
&&
permutation
[
1
]
==
2
)
{
// 3d tensor case: (0, 1, 2) -> (0, 2, 1)
// 3d tensor case: (0, 1, 2) -> (0, 2, 1)
return
true
;
return
true
;
}
else
{
}
else
{
return
false
;
return
false
;
}
}
}
}
return
false
;
return
false
;
}
}
// True when a 2-byte-element transpose can use 4-byte loads/stores: the element
// movement size must be 2, both extents even, and both pointers 4-byte aligned.
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
  const auto src_addr = reinterpret_cast<std::uintptr_t>(src);
  const auto dst_addr = reinterpret_cast<std::uintptr_t>(dst);
  if (movement_size != 2) { return false; }
  if (rows % 2 != 0 || cols % 2 != 0) { return false; }
  return (src_addr % 4 == 0) && (dst_addr % 4 == 0);
}
// Splits src_dims into (num_batches, rows, cols): a 2-D input is a single
// matrix (batch of 1); a 3-D input is a batch of matrices.
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
                              IndexType* cols) {
  const bool batched = (num_dims != 2);
  *num_batches = batched ? static_cast<IndexType>(src_dims[0]) : static_cast<IndexType>(1);
  *rows = static_cast<IndexType>(src_dims[batched ? 1 : 0]);
  *cols = static_cast<IndexType>(src_dims[batched ? 2 : 1]);
}
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int
*
permutation
,
void
LaunchKernel
(
Stream
*
stream
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int
*
permutation
,
void
*
dst
,
size_t
count
)
{
void
*
dst
,
size_t
count
)
{
PermuteKernelParams
<
num_dims
,
IndexType
>
params
=
PermuteKernelParams
<
num_dims
,
IndexType
>
params
=
MakePermuteParams
<
num_dims
,
IndexType
>
(
src_dims
,
src
,
permutation
,
dst
,
count
);
MakePermuteParams
<
num_dims
,
IndexType
>
(
src_dims
,
src
,
permutation
,
dst
,
count
);
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
if
(
num_dims
==
2
||
num_dims
==
3
)
{
if
(
num_dims
==
2
||
num_dims
==
3
)
{
IndexType
num_batches
;
IndexType
num_batches
;
IndexType
rows
;
IndexType
rows
;
IndexType
cols
;
IndexType
cols
;
InferBatchTransposeShape
<
num_dims
,
IndexType
>
(
src_dims
,
&
num_batches
,
&
rows
,
&
cols
);
InferBatchTransposeShape
<
num_dims
,
IndexType
>
(
src_dims
,
&
num_batches
,
&
rows
,
&
cols
);
if
(
CheckLaunchBatchTranspose
<
num_dims
,
kMov4TileSize
>
(
params
.
permutation
,
num_batches
,
rows
,
if
(
CheckLaunchBatchTranspose
<
num_dims
,
kMov4TileSize
>
(
params
.
permutation
,
num_batches
,
rows
,
cols
))
{
cols
))
{
if
(
CheckUseMov2
<
IndexType
,
movement_size
>
(
rows
,
cols
,
src
,
dst
))
{
if
(
CheckUseMov2
<
IndexType
,
movement_size
>
(
rows
,
cols
,
src
,
dst
))
{
LaunchBatchTransposeKernel
<
num_dims
,
2
,
kMov2TileSize
,
IndexType
>
(
cuda_stream
,
params
,
LaunchBatchTransposeKernel
<
num_dims
,
2
,
kMov2TileSize
,
IndexType
>
(
cuda_stream
,
params
,
num_batches
,
rows
,
cols
);
num_batches
,
rows
,
cols
);
}
else
{
}
else
{
LaunchBatchTransposeKernel
<
num_dims
,
movement_size
,
kMov4TileSize
,
IndexType
>
(
LaunchBatchTransposeKernel
<
num_dims
,
movement_size
,
kMov4TileSize
,
IndexType
>
(
cuda_stream
,
params
,
num_batches
,
rows
,
cols
);
cuda_stream
,
params
,
num_batches
,
rows
,
cols
);
}
}
}
else
{
}
else
{
PermuteKernel
<
num_dims
,
movement_size
,
IndexType
>
PermuteKernel
<
num_dims
,
movement_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
}
}
}
else
{
}
else
{
PermuteKernel
<
num_dims
,
movement_size
,
IndexType
>
PermuteKernel
<
num_dims
,
movement_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
}
}
}
}
class
PermuteImpl
:
public
Permute
{
class
PermuteImpl
:
public
Permute
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
PermuteImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
PermuteImpl
);
PermuteImpl
()
=
default
;
PermuteImpl
()
=
default
;
~
PermuteImpl
()
override
=
default
;
~
PermuteImpl
()
override
=
default
;
using
Permute
::
Launch
;
using
Permute
::
Launch
;
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int
*
permutation
,
void
*
dst
)
override
{
const
void
*
src
,
const
int
*
permutation
,
void
*
dst
)
override
{
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
src_dims
,
src
,
permutation
,
dst
);
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
src_dims
,
src
,
permutation
,
dst
);
}
}
};
};
class
PermuteFactoryImpl
:
public
PermuteFactory
{
class
PermuteFactoryImpl
:
public
PermuteFactory
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
PermuteFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
PermuteFactoryImpl
);
PermuteFactoryImpl
()
=
default
;
PermuteFactoryImpl
()
=
default
;
~
PermuteFactoryImpl
()
override
=
default
;
~
PermuteFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Permute
>
New
(
size_t
max_num_dims
)
override
{
std
::
unique_ptr
<
Permute
>
New
(
size_t
max_num_dims
)
override
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
return
std
::
unique_ptr
<
Permute
>
(
new
PermuteImpl
());
return
std
::
unique_ptr
<
Permute
>
(
new
PermuteImpl
());
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Register the permute primitive factory for the CUDA device type (ROCm/HIP backend).
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl);
}
// namespace
}
// namespace
}
// namespace internal
}
// namespace internal
}
// namespace permute
}
// namespace permute
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
// Selects which normalization the shared softmax implementation computes.
enum class Algorithm {
  kSoftmax,
  kLogSoftmax,
};
template
<
Algorithm
algorithm
,
typename
T
>
template
<
Algorithm
algorithm
,
typename
T
>
void
SoftmaxGpu
(
hipStream_t
cuda_stream
,
size_t
rows
,
size_t
cols
,
const
T
*
x
,
T
*
y
)
{
void
SoftmaxGpu
(
hipStream_t
cuda_stream
,
size_t
rows
,
size_t
cols
,
const
T
*
x
,
T
*
y
)
{
using
ComputeType
=
typename
cuda
::
softmax
::
DefaultComputeType
<
T
>::
type
;
using
ComputeType
=
typename
cuda
::
softmax
::
DefaultComputeType
<
T
>::
type
;
oneflow
::
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load
(
x
,
cols
);
oneflow
::
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load
(
x
,
cols
);
oneflow
::
cuda
::
softmax
::
DirectStore
<
ComputeType
,
T
>
store
(
y
,
cols
);
oneflow
::
cuda
::
softmax
::
DirectStore
<
ComputeType
,
T
>
store
(
y
,
cols
);
if
(
algorithm
==
Algorithm
::
kSoftmax
)
{
if
(
algorithm
==
Algorithm
::
kSoftmax
)
{
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchSoftmax
<
decltype
(
load
),
decltype
(
store
),
ComputeType
>
(
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchSoftmax
<
decltype
(
load
),
decltype
(
store
),
ComputeType
>
(
cuda_stream
,
load
,
store
,
rows
,
cols
)));
cuda_stream
,
load
,
store
,
rows
,
cols
)));
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchLogSoftmax
<
decltype
(
load
),
decltype
(
store
),
ComputeType
>
(
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchLogSoftmax
<
decltype
(
load
),
decltype
(
store
),
ComputeType
>
(
cuda_stream
,
load
,
store
,
rows
,
cols
)));
cuda_stream
,
load
,
store
,
rows
,
cols
)));
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
}
}
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
class
SoftmaxImpl
:
public
SoftmaxBase
{
class
SoftmaxImpl
:
public
SoftmaxBase
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
SoftmaxImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
SoftmaxImpl
);
SoftmaxImpl
()
=
default
;
SoftmaxImpl
()
=
default
;
~
SoftmaxImpl
()
override
=
default
;
~
SoftmaxImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
rows
,
size_t
cols
,
const
void
*
x
,
void
*
y
)
override
{
void
Launch
(
Stream
*
stream
,
size_t
rows
,
size_t
cols
,
const
void
*
x
,
void
*
y
)
override
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
SoftmaxGpu
<
algorithm
,
T
>
(
cuda_stream
,
rows
,
cols
,
reinterpret_cast
<
const
T
*>
(
x
),
SoftmaxGpu
<
algorithm
,
T
>
(
cuda_stream
,
rows
,
cols
,
reinterpret_cast
<
const
T
*>
(
x
),
reinterpret_cast
<
T
*>
(
y
));
reinterpret_cast
<
T
*>
(
y
));
}
}
};
};
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBase
>
NewSoftmax
()
{
std
::
unique_ptr
<
SoftmaxBase
>
NewSoftmax
()
{
return
std
::
unique_ptr
<
SoftmaxBase
>
(
new
SoftmaxImpl
<
SoftmaxBase
,
algorithm
,
T
>
());
return
std
::
unique_ptr
<
SoftmaxBase
>
(
new
SoftmaxImpl
<
SoftmaxBase
,
algorithm
,
T
>
());
}
}
template
<
typename
FactoryBase
,
typename
SoftmaxBase
,
Algorithm
algorithm
>
template
<
typename
FactoryBase
,
typename
SoftmaxBase
,
Algorithm
algorithm
>
class
GenericSoftmaxFactoryImpl
:
public
FactoryBase
{
class
GenericSoftmaxFactoryImpl
:
public
FactoryBase
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
GenericSoftmaxFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
GenericSoftmaxFactoryImpl
);
GenericSoftmaxFactoryImpl
()
=
default
;
GenericSoftmaxFactoryImpl
()
=
default
;
~
GenericSoftmaxFactoryImpl
()
override
=
default
;
~
GenericSoftmaxFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
SoftmaxBase
>
New
(
DataType
data_type
)
override
{
std
::
unique_ptr
<
SoftmaxBase
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{
type_proto
,
NewSoftmax
<
SoftmaxBase
,
algorithm
,
type_cpp
>
},
{type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
SoftmaxBase
>
()
>>
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
SoftmaxBase
>
()
>>
new_softmax_handle
{
new_softmax_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_SOFTMAX_ENTRY
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_SOFTMAX_ENTRY
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_SOFTMAX_ENTRY
#undef MAKE_NEW_SOFTMAX_ENTRY
const
auto
it
=
new_softmax_handle
.
find
(
data_type
);
const
auto
it
=
new_softmax_handle
.
find
(
data_type
);
if
(
it
!=
new_softmax_handle
.
end
())
{
if
(
it
!=
new_softmax_handle
.
end
())
{
return
it
->
second
();
return
it
->
second
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Concrete factory aliases and their device-type registrations.
using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, Algorithm::kSoftmax>;
using LogSoftmaxFactoryImpl =
    GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
{
namespace
{
// Selects which normalization gradient the shared backward implementation computes.
enum class Algorithm {
  kSoftmax,
  kLogSoftmax,
};
template
<
Algorithm
algorithm
,
typename
T
>
template
<
Algorithm
algorithm
,
typename
T
>
void
SoftmaxBackwardGpu
(
hipStream_t
cuda_stream
,
size_t
rows
,
size_t
cols
,
const
T
*
y
,
const
T
*
dy
,
void
SoftmaxBackwardGpu
(
hipStream_t
cuda_stream
,
size_t
rows
,
size_t
cols
,
const
T
*
y
,
const
T
*
dy
,
T
*
dx
)
{
T
*
dx
)
{
using
ComputeType
=
typename
cuda
::
softmax
::
DefaultComputeType
<
T
>::
type
;
using
ComputeType
=
typename
cuda
::
softmax
::
DefaultComputeType
<
T
>::
type
;
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load_y
(
y
,
cols
);
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load_y
(
y
,
cols
);
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load_dy
(
dy
,
cols
);
cuda
::
softmax
::
DirectLoad
<
T
,
ComputeType
>
load_dy
(
dy
,
cols
);
cuda
::
softmax
::
DirectStore
<
ComputeType
,
T
>
store
(
dx
,
cols
);
cuda
::
softmax
::
DirectStore
<
ComputeType
,
T
>
store
(
dx
,
cols
);
if
(
algorithm
==
Algorithm
::
kSoftmax
)
{
if
(
algorithm
==
Algorithm
::
kSoftmax
)
{
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchSoftmaxGrad
<
decltype
(
load_y
),
decltype
(
load_dy
),
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchSoftmaxGrad
<
decltype
(
load_y
),
decltype
(
load_dy
),
decltype
(
store
),
ComputeType
>
(
decltype
(
store
),
ComputeType
>
(
cuda_stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
)));
cuda_stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
)));
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
}
else
if
(
algorithm
==
Algorithm
::
kLogSoftmax
)
{
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchLogSoftmaxGrad
<
decltype
(
load_y
),
decltype
(
load_dy
),
OF_CUDA_CHECK
((
cuda
::
softmax
::
DispatchLogSoftmaxGrad
<
decltype
(
load_y
),
decltype
(
load_dy
),
decltype
(
store
),
ComputeType
>
(
decltype
(
store
),
ComputeType
>
(
cuda_stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
)));
cuda_stream
,
load_y
,
load_dy
,
store
,
rows
,
cols
)));
}
else
{
}
else
{
UNIMPLEMENTED
();
UNIMPLEMENTED
();
}
}
}
}
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
class
SoftmaxBackwardImpl
:
public
SoftmaxBackwardBase
{
class
SoftmaxBackwardImpl
:
public
SoftmaxBackwardBase
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
SoftmaxBackwardImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
SoftmaxBackwardImpl
);
SoftmaxBackwardImpl
()
=
default
;
SoftmaxBackwardImpl
()
=
default
;
~
SoftmaxBackwardImpl
()
override
=
default
;
~
SoftmaxBackwardImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
rows
,
size_t
cols
,
const
void
*
y
,
const
void
*
dy
,
void
Launch
(
Stream
*
stream
,
size_t
rows
,
size_t
cols
,
const
void
*
y
,
const
void
*
dy
,
void
*
dx
)
override
{
void
*
dx
)
override
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
SoftmaxBackwardGpu
<
algorithm
,
T
>
(
cuda_stream
,
rows
,
cols
,
reinterpret_cast
<
const
T
*>
(
y
),
SoftmaxBackwardGpu
<
algorithm
,
T
>
(
cuda_stream
,
rows
,
cols
,
reinterpret_cast
<
const
T
*>
(
y
),
reinterpret_cast
<
const
T
*>
(
dy
),
reinterpret_cast
<
T
*>
(
dx
));
reinterpret_cast
<
const
T
*>
(
dy
),
reinterpret_cast
<
T
*>
(
dx
));
}
}
};
};
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBackwardBase
>
NewSoftmaxBackward
()
{
std
::
unique_ptr
<
SoftmaxBackwardBase
>
NewSoftmaxBackward
()
{
return
std
::
unique_ptr
<
SoftmaxBackwardBase
>
(
return
std
::
unique_ptr
<
SoftmaxBackwardBase
>
(
new
SoftmaxBackwardImpl
<
SoftmaxBackwardBase
,
algorithm
,
T
>
());
new
SoftmaxBackwardImpl
<
SoftmaxBackwardBase
,
algorithm
,
T
>
());
}
}
template
<
typename
BackwardFactoryBase
,
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
>
template
<
typename
BackwardFactoryBase
,
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
>
class
GenericSoftmaxBackwardFactoryImpl
:
public
BackwardFactoryBase
{
class
GenericSoftmaxBackwardFactoryImpl
:
public
BackwardFactoryBase
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
GenericSoftmaxBackwardFactoryImpl
);
OF_DISALLOW_COPY_AND_MOVE
(
GenericSoftmaxBackwardFactoryImpl
);
GenericSoftmaxBackwardFactoryImpl
()
=
default
;
GenericSoftmaxBackwardFactoryImpl
()
=
default
;
~
GenericSoftmaxBackwardFactoryImpl
()
override
=
default
;
~
GenericSoftmaxBackwardFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
SoftmaxBackwardBase
>
New
(
DataType
data_type
)
override
{
std
::
unique_ptr
<
SoftmaxBackwardBase
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{
type_proto
,
NewSoftmaxBackward
<
SoftmaxBackwardBase
,
algorithm
,
type_cpp
>
},
{type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
SoftmaxBackwardBase
>
()
>>
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
SoftmaxBackwardBase
>
()
>>
new_softmax_backward_handle
{
new_softmax_backward_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_SOFTMAX_ENTRY
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_SOFTMAX_ENTRY
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_SOFTMAX_ENTRY
#undef MAKE_NEW_SOFTMAX_ENTRY
const
auto
it
=
new_softmax_backward_handle
.
find
(
data_type
);
const
auto
it
=
new_softmax_backward_handle
.
find
(
data_type
);
if
(
it
!=
new_softmax_backward_handle
.
end
())
{
if
(
it
!=
new_softmax_backward_handle
.
end
())
{
return
it
->
second
();
return
it
->
second
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
}
}
}
}
};
};
// Bind the generic factory to the two concrete primitive interfaces.
using SoftmaxBackwardFactoryImpl =
    GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward,
                                      Algorithm::kSoftmax>;
using LogSoftmaxBackwardFactoryImpl =
    GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,
                                      Algorithm::kLogSoftmax>;
// NOTE(review): registered under DeviceType::kCUDA even though this is the
// ROCm backend — presumably the ROCm build reuses the CUDA device tag; confirm.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory,
                           LogSoftmaxBackwardFactoryImpl);
}
// namespace
}
// namespace
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/type_seq.h
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
// Include guard renamed from the CUDA header's guard: this ROCm copy kept
// ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_, so if both headers ever reached
// one translation unit the second would be silently skipped.
#ifndef ONEFLOW_CORE_EP_ROCM_PRIMITIVE_TYPE_SEQ_H_
#define ONEFLOW_CORE_EP_ROCM_PRIMITIVE_TYPE_SEQ_H_

#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// bfloat16 is disabled in this ROCm port; the upstream CUDA>=11 include is
// kept for reference.
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif  // CUDA_VERSION >= 11000

// (cpp type, DataType proto tag) tuple sequences consumed via
// OF_PP_FOR_EACH_TUPLE. The CUDA_* macro names are preserved so
// device-generic primitive code compiles unchanged against this header.
#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
// bfloat16 entry expands to nothing on ROCm (upstream guards on CUDA >= 11).
// #if CUDA_VERSION >= 11000
// #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// #endif  // CUDA_VERSION >= 11000

// NOTE(review): UINT32/UINT64 are defined above but intentionally absent from
// ALL_TYPE_SEQ — this mirrors the upstream CUDA header; confirm before adding.
#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
  CUDA_PRIMITIVE_BOOL_TYPE_SEQ      \
  CUDA_PRIMITIVE_CHAR_TYPE_SEQ      \
  CUDA_PRIMITIVE_INT8_TYPE_SEQ      \
  CUDA_PRIMITIVE_UINT8_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT32_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT64_TYPE_SEQ     \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ     \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ    \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ   \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ          \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ         \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ        \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#define UTIL_OPS_DATA_TYPE_SEQ      \
  CUDA_PRIMITIVE_INT8_TYPE_SEQ      \
  CUDA_PRIMITIVE_UINT8_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT32_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT64_TYPE_SEQ     \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ     \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ    \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ   \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#endif  // WITH_ROCM

#endif  // ONEFLOW_CORE_EP_ROCM_PRIMITIVE_TYPE_SEQ_H_
\ No newline at end of file
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
View file @
8f7de847
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
template
<
typename
Dst
,
typename
Src
>
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kGelu
,
Dst
,
Src
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kGelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Src
>
(
0.5
)
*
src
return
static_cast
<
Src
>
(
0.5
)
*
src
*
(
static_cast
<
Src
>
(
1.0
)
+
erf
(
static_cast
<
Src
>
(
M_SQRT1_2
)
*
src
));
*
(
static_cast
<
Src
>
(
1.0
)
+
erf
(
static_cast
<
Src
>
(
M_SQRT1_2
)
*
src
));
}
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
float
,
float
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
float
,
float
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
float
operator
()(
float
src
)
const
{
return
tanhf
(
src
);
}
OF_DEVICE_FUNC
float
operator
()(
float
src
)
const
{
return
tanhf
(
src
);
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
double
,
double
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
double
,
double
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
double
operator
()(
double
src
)
const
{
return
tanh
(
src
);
}
OF_DEVICE_FUNC
double
operator
()(
double
src
)
const
{
return
tanh
(
src
);
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
half
,
half
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kTanh
,
half
,
half
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
return
__float2half
(
tanhf
(
__half2float
(
src
)));
}
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
return
__float2half
(
tanhf
(
__half2float
(
src
)));
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
half
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
half
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
half
src
)
const
{
return
isinf
(
__half2float
(
src
));
}
OF_DEVICE_FUNC
bool
operator
()(
half
src
)
const
{
return
isinf
(
__half2float
(
src
));
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
float
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
float
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
float
src
)
const
{
return
isinf
(
src
);
}
OF_DEVICE_FUNC
bool
operator
()(
float
src
)
const
{
return
isinf
(
src
);
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
double
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsInf
,
bool
,
double
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
double
src
)
const
{
return
isinf
(
src
);
}
OF_DEVICE_FUNC
bool
operator
()(
double
src
)
const
{
return
isinf
(
src
);
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
half
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
half
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
half
src
)
const
{
return
isnan
(
__half2float
(
src
));
}
OF_DEVICE_FUNC
bool
operator
()(
half
src
)
const
{
return
isnan
(
__half2float
(
src
));
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
float
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
float
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
float
src
)
const
{
return
isnan
(
src
);
}
OF_DEVICE_FUNC
bool
operator
()(
float
src
)
const
{
return
isnan
(
src
);
}
};
};
template
<
>
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
double
>
{
struct
UnaryFunctor
<
DeviceType
::
kCUDA
,
UnaryOp
::
kIsNan
,
bool
,
double
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
double
src
)
const
{
return
isnan
(
src
);
}
OF_DEVICE_FUNC
bool
operator
()(
double
src
)
const
{
return
isnan
(
src
);
}
};
};
// Defines a half-precision UnaryFunctor specialization for `op` that has no
// native half implementation: the operand is promoted to float, evaluated by
// the existing float functor, and rounded back to half.
#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op)                          \
  template<>                                                                  \
  struct UnaryFunctor<DeviceType::kCUDA, op, half, half> {                    \
    UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
                                                                              \
    UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;          \
    OF_DEVICE_FUNC half operator()(half src) const {                          \
      return __float2half(float_functor(__half2float(src)));                  \
    }                                                                         \
  };
// Instantiate the float-promoting half specializations for the activation ops
// implemented only in float above.
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// /*********nv_bfloat16_kernel*******/
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \
// template<> \
// template<> \
// struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// \
// UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src))); \
// return __float2bfloat16(float_functor(__bfloat162float(src))); \
// } \
// } \
// };
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// template<>
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// };
// };
// template<>
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// };
// };
// #endif
// #endif
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment