gaoqiong / onnxruntime_v14 · Commits

Commit 1a91fcc2, authored Jul 25, 2023 by gaoqiong
Commit message: add dtk所需文件 (add files required by dtk)
Parent: a144865d
Pipeline #492 failed with stages in 0 seconds
Changes: 280 · Pipelines: 1
Showing 20 changed files with 1382 additions and 0 deletions (+1382 -0)
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk.h  +23 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl.cuh  +474 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl.h  +17 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f16.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f32.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f64.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i16.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i32.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i64.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i8.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u16.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u32.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u64.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u8.cu  +5 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops.cc  +167 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops.h  +116 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops_impl.cu  +164 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops_impl.h  +51 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/variadic_elementwise_ops.cc  +273 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/variadic_elementwise_ops.h  +42 -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk.h
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {

template <bool inputk>
class TopK final : public RocmKernel {
 public:
  TopK(const OpKernelInfo&);
  Status ComputeInternal(OpKernelContext*) const override;

 private:
  int64_t axis_;
  int64_t largest_;
  int64_t sorted_;
  mutable int64_t K_;
};

}  // namespace rocm
}  // namespace onnxruntime
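Note: the matching topk.cc is not part of this page, so the following is only a rough, non-authoritative sketch of how the constructor declared above would typically read its attributes. The attribute names follow the ONNX TopK spec; the GetAttrOrDefault usage and default values are assumptions, not something this commit contains.

// Illustrative sketch only - not part of this commit.
template <bool inputk>
TopK<inputk>::TopK(const OpKernelInfo& info) : RocmKernel(info) {
  info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);       // last axis by default
  info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);  // return the largest K by default
  info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);    // return sorted results by default
  if (!inputk) info.GetAttrOrDefault<int64_t>("k", &K_, 0); // older opsets carry k as an attribute
}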
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl.cuh
0 → 100644
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "topk_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "hipcub/hipcub.hpp"
#include <hipcub/backend/rocprim/util_type.hpp>
#include <hipcub/util_allocator.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/backend/rocprim/device/device_radix_sort.hpp>
#include <limits>
//TODO:fix the warnings
#ifdef _MSC_VER
#pragma warning(disable : 4244)
#endif
namespace onnxruntime {
namespace rocm {

using namespace hipcub;

template <typename T>
struct KV {
  T key;
  int64_t val;
};
#define BT GridDim::maxThreadsPerBlock
#define ALIGN(N) static_cast<int64_t>(pow(2, ceil(log2(static_cast<double>(N)))))
#define FROM(idx) (left_dim + (idx)*mid_dim + right_dim)
#define TO(idx) (left_dim * K / dimension + (idx)*mid_dim + right_dim)
#define TRIVIAL (1 == largest ? type_min : type_max)
#define BIGGER(n, m) (n.key > m.key ? n : (n.key < m.key ? m : (n.val > m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
#define SMALLER(n, m) (n.key < m.key ? n : (n.key > m.key ? m : (n.val < m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
#define IS_SMALLER(n, m) (n.key < m.key || !(n.key > m.key) && (1 == largest ? n.val > m.val : n.val < m.val))
#define LESS(n, m) ((n) <= (m) ? (n) : (m))
template <typename T>
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size,
                            int32_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted,
                            int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
  int64_t tid = threadIdx.x;
  int64_t bid = blockIdx.x;
  int64_t bdim = blockDim.x;
  extern __shared__ char shared_mem[];
  auto S = (KV<T>*)(shared_mem);
  auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
  auto left_dim = bid / mid_dim * elem_nums[axis];
  auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
  for (auto i = tid; i < aligned_dimension; i += bdim) {
    S[i].key = i < dimension ? X[FROM(i)] : TRIVIAL;
    S[i].val = i;
  }
  __syncthreads();
  //sort each K
  for (int64_t len = 1; len < aligned_K; len <<= 1) {
    auto dir = len << 1;
    for (auto inc = len; inc > 0; inc >>= 1) {
      auto low = tid & (inc - 1);
      auto i = (tid << 1) - low;
      auto j = i + inc;
      if (j < aligned_dimension) {
        auto reverse = (dir & i) == 0;
        auto swap = reverse ^ IS_SMALLER(S[i], S[j]);
        if (swap) {
          auto tmp = S[i];
          S[i] = S[j];
          S[j] = tmp;
        }
      }
      __syncthreads();
    }
    __syncthreads();
  }
  //merge and rebuild K
  for (int64_t len = aligned_K; len < aligned_dimension; len <<= 1) {
    auto dir = len << 1;
    auto i = (tid << 1) - (tid & (len - 1));
    auto j = i + len;
    if (i % dir < aligned_K && j < aligned_dimension) {
      S[i] = 1 == largest ? BIGGER(S[i], S[j]) : SMALLER(S[i], S[j]);
    }
    __syncthreads();
    for (auto inc = aligned_K >> 1; inc > 0; inc >>= 1) {
      auto ii = (tid << 1) - (tid & (inc - 1));
      auto jj = ii + inc;
      if (ii % dir < aligned_K && jj < aligned_dimension) {
        auto reverse = (dir & ii) == 0;
        auto swap = reverse ^ IS_SMALLER(S[ii], S[jj]);
        if (swap) {
          auto tmp = S[ii];
          S[ii] = S[jj];
          S[jj] = tmp;
        }
      }
      __syncthreads();
    }
    __syncthreads();
  }
  //save top K
  if (1 == sorted) {
    if (1 == largest) {
      auto start = aligned_K - K;
      if (tid >= start && tid < aligned_K) {
        auto to = TO(aligned_K - 1 - tid);
        V[to] = S[tid].key;
        I[to] = S[tid].val;
      }
    } else {
      if (tid < K) {
        auto to = TO(tid);
        V[to] = S[tid].key;
        I[to] = S[tid].val;
      }
    }
  } else {
    if (1 == largest) {
      auto start = aligned_K - K;
      if (tid < start) {
        S[tid].val = aligned_dimension;
      }
    } else {
      if (tid >= K && tid < aligned_K) {
        S[tid].val = aligned_dimension;
      }
    }
    __syncthreads();
    //sort by index ascending
    for (int64_t len = 1; len < aligned_K; len <<= 1) {
      auto dir = len << 1;
      for (int64_t inc = len; inc > 0; inc >>= 1) {
        auto low = tid & (inc - 1);
        auto i = (tid << 1) - low;
        auto j = i + inc;
        if (j < aligned_K) {
          auto reverse = (dir & i) == 0;
          auto swap = reverse ^ (S[i].val < S[j].val);
          if (swap) {
            auto tmp = S[i];
            S[i] = S[j];
            S[j] = tmp;
          }
        }
        __syncthreads();
      }
      __syncthreads();
    }
    if (tid < K) {
      auto to = TO(tid);
      V[to] = S[tid].key;
      I[to] = S[tid].val;
    }
  }
}

template <typename T>
__device__ __forceinline__ bool Equal(const T& t0, const T& t1) {
  return t0 == t1;
}

__device__ __forceinline__ bool Equal(const float& t0, const float& t1) {
  return !(t0 > t1 || t1 > t0);
}

__device__ __forceinline__ bool Equal(const double& t0, const double& t1) {
  return !(t0 > t1 || t1 > t0);
}

template <typename T>
__device__ __forceinline__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
  return ((*t0) ^ (*t1)) >> skip == 0;
}

__device__ __forceinline__ bool SamePrefix(const half* f0, const half* f1, int64_t skip) {
  return SamePrefix((const int16_t*)f0, (const int16_t*)f1, skip);
}

__device__ __forceinline__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
  return SamePrefix((const int32_t*)f0, (const int32_t*)f1, skip);
}

__device__ __forceinline__ bool SamePrefix(const double* d0, const double* d1, int64_t skip) {
  return SamePrefix((const int64_t*)d0, (const int64_t*)d1, skip);
}

template <typename T>
__device__ __forceinline__ int32_t Radix(const T* t, int64_t skip) {
  return ((*t) >> skip) & 255;
}

__device__ __forceinline__ int32_t Radix(const half* f, int64_t skip) {
  return Radix((const int16_t*)f, skip);
}

__device__ __forceinline__ int32_t Radix(const float* f, int64_t skip) {
  return Radix((const int32_t*)f, skip);
}

__device__ __forceinline__ int32_t Radix(const double* d, int64_t skip) {
  return Radix((const int64_t*)d, skip);
}

template <typename T>
__device__ void SetByte(T* t, int64_t byte) {
  (*t) |= byte;
}

__device__ __forceinline__ void SetByte(half* f, int64_t byte) {
  SetByte((int16_t*)f, byte);
}

__device__ __forceinline__ void SetByte(float* f, int64_t byte) {
  SetByte((int32_t*)f, byte);
}

__device__ __forceinline__ void SetByte(double* d, int64_t byte) {
  SetByte((int64_t*)d, byte);
}

template <typename T, int64_t THREADS, int64_t KPT>
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis,
                          int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT,
                          T type_min, T type_max) {
  auto tid = threadIdx.x;
  auto bid = blockIdx.x;
  extern __shared__ char shared_mem[];
  auto H = (uint32_t*)shared_mem;
  auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
  auto left_dim = bid / mid_dim * elem_nums[axis];
  auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
  T Kth = (T)0, sign = (T)1;
  typedef BlockScan<uint32_t, THREADS> BlockScan;
  typedef BlockReduce<uint32_t, THREADS> BlockReduce;
  typedef BlockRadixSort<T, THREADS, KPT, int64_t> BlockRadixSort;
  __shared__ union {
    typename BlockScan::TempStorage scan;
    typename BlockReduce::TempStorage reduce;
    typename BlockRadixSort::TempStorage sort;
  } temp_storage;
  uint32_t positive = 0, negative = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    T x = X[FROM(x_i)];
    if (x > (T)0) {
      ++positive;
    } else if (x < (T)0) {
      ++negative;
    }
  }
  __syncthreads();
  positive = BlockReduce(temp_storage.reduce).Sum(positive);
  __syncthreads();
  negative = BlockReduce(temp_storage.reduce).Sum(negative);
  if (0 == tid) {
    H[0] = positive;
    H[1] = negative;
  }
  __syncthreads();
  positive = H[0];
  negative = H[1];
  if ((1 == largest && (K <= positive || dimension - K + 1 <= negative)) ||
      (0 == largest && (K <= negative || dimension - K + 1 <= positive))) {
    auto KK = K;
    if (1 == largest) {
      if (KK > positive) {
        KK = dimension - KK + 1;
        sign = (T)-1;
      }
    } else {
      if (KK > negative) {
        KK = dimension - KK + 1;
      } else {
        sign = (T)-1;
      }
    }
    __syncthreads();
#pragma unroll
    for (int64_t byte = sizeof(T) - 1; byte > -1; --byte) {
      if (tid < 256) H[tid] = 0;
      __syncthreads();
      auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
      for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
        T x = sign * X[FROM(x_i)];
        if (x > (T)0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
          atomicAdd(&H[Radix(&x, skip)], 1);
        }
      }
      __syncthreads();
      for (int64_t radix = 255; radix > 0; --radix) {
        if (H[radix] < KK) {
          KK -= H[radix];
        } else {
          SetByte(&Kth, radix << skip);
          break;
        }
      }
      __syncthreads();
    }
    Kth *= sign;
  }
  uint32_t superior = 0, equal = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
      ++superior;
    } else if (Equal(x, Kth)) {
      ++equal;
    }
  }
  __syncthreads();
  auto all_superior = superior;
  all_superior = BlockReduce(temp_storage.reduce).Sum(all_superior);
  if (0 == tid) {
    H[0] = all_superior;
  }
  __syncthreads();
  all_superior = H[0];
  BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
  __syncthreads();
  BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
  __syncthreads();
  auto equal_quota = K - all_superior - equal;
  auto output_i = superior + LESS(K - all_superior, equal);
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
    } else if (Equal(x, Kth) && equal_quota > 0) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
      --equal_quota;
    }
  }
  __syncthreads();
  if (1 == sorted) {
    T keys[KPT];
    int64_t vals[KPT];
    for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) {
      if (k_i < K) {
        auto to_i = TO(k_i);
        keys[k_c] = V[to_i];
        vals[k_c] = I[to_i];
      } else {
        if (1 == largest) {
          keys[k_c] = type_min;
        } else {
          keys[k_c] = type_max;
        }
      }
    }
    __syncthreads();
    if (1 == largest) {
      BlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
    } else {
      BlockRadixSort(temp_storage.sort).Sort(keys, vals);
    }
    __syncthreads();
#pragma unroll
    for (int64_t k_c = 0; k_c < KPT; ++k_c) {
      auto k_i = tid * KPT + k_c;
      if (k_i < K) {
        auto to_i = TO(k_i);
        V[to_i] = keys[k_c];
        I[to_i] = vals[k_c];
      }
    }
  }
}

template <typename T>
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums,
                          size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
  auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
  auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
  auto input_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
  output_v[id] = input_x[input_offset];
  output_i[id] = id;
}

template <typename T>
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i,
                           const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset,
                           int64_t dimension) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, K);
  auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
  auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
  auto output_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
  output_v[output_offset] = input_v[id];
  output_i[output_offset] = input_i[id];
}

// template is used to avoid linking issue, since __global__ function cannot be inline-ed
template <typename T>
__global__ void ExcludeOutput(T* output_i, T K, T dimension) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
  if (id >= K) {
    output_i[id] = dimension;
  }
}

template <typename T>
Status TopKImpl(const RocmKernel* kernel, hipStream_t stream, const T* input_x, T* output_v, int64_t* output_i,
                const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest,
                int64_t sorted, int64_t N, int64_t dimension) {
  typedef typename ToHipType<T>::MappedType HipT;
  const HipT* input_x_ptr = reinterpret_cast<const HipT*>(input_x);
  HipT* output_v_ptr = reinterpret_cast<HipT*>(output_v);

  auto aligned_K = ALIGN(K);
  auto aligned_dimension = ALIGN(dimension);
  if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
    BitonicTopK<HipT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<HipT>), stream>>>(
        input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension,
        aligned_dimension, NumericLimits<T>::Min(), NumericLimits<T>::Max());
  } else if (K <= BT * 16 || 0 == sorted) {
    auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
    if (BT * 2 >= K || 0 == sorted) {
      RadixTopK<HipT, BT, 2><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
          input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
          NumericLimits<T>::Min(), NumericLimits<T>::Max());
    } else if (BT * 4 >= K) {
      RadixTopK<HipT, BT, 4><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
          input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
          NumericLimits<T>::Min(), NumericLimits<T>::Max());
    } else if (BT * 8 >= K) {
      RadixTopK<HipT, BT, 8><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
          input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
          NumericLimits<T>::Min(), NumericLimits<T>::Max());
    } else {
      RadixTopK<HipT, BT, 16><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
          input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
          NumericLimits<T>::Min(), NumericLimits<T>::Max());
    }
  } else {
    auto input_key_buffer = kernel->GetScratchBuffer<HipT>(dimension);
    auto output_key_buffer = kernel->GetScratchBuffer<HipT>(dimension);
    auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
    auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
    auto* input_key = input_key_buffer.get();
    auto* output_key = output_key_buffer.get();
    auto* input_value = input_value_buffer.get();
    auto* output_value = output_value_buffer.get();
    size_t temp_bytes = 0;
    HIP_RETURN_IF_ERROR(hipcub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value,
                                                           output_value, dimension, 0, sizeof(T) * 8, stream));
    auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
    auto* temp_storage = temp_storage_buffer.get();
    auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
    auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
    for (int64_t i = 0; i < N; i++) {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(FillInput<HipT>), blocks_per_grid_D, BT, 0, stream,
                         input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension);
      HIP_RETURN_IF_ERROR(1 == largest ? hipcub::DeviceRadixSort::SortPairsDescending(
                                             temp_storage, temp_bytes, input_key, output_key, input_value,
                                             output_value, dimension, 0, sizeof(T) * 8, stream)
                                       : hipcub::DeviceRadixSort::SortPairs(
                                             temp_storage, temp_bytes, input_key, output_key, input_value,
                                             output_value, dimension, 0, sizeof(T) * 8, stream));
      if (1 == sorted) {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(FillOutput<HipT>), blocks_per_grid_K, BT, 0, stream,
                           output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
      } else {
        //reorder by ascending index
        hipLaunchKernelGGL(HIP_KERNEL_NAME(ExcludeOutput<int64_t>), blocks_per_grid_D, BT, 0, stream,
                           output_value, K, dimension);
        HIP_RETURN_IF_ERROR(hipcub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value,
                                                               output_key, input_key, dimension, 0, sizeof(T) * 8, stream));
        hipLaunchKernelGGL(HIP_KERNEL_NAME(FillOutput<HipT>), blocks_per_grid_K, BT, 0, stream,
                           input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
      }
    }
  }
  return Status::OK();
}
#define TOPKIMPLE(T) template Status TopKImpl<T>(const RocmKernel* kernel, \
hipStream_t stream, \
const T* input_x, \
T* output_v, \
int64_t* output_i, \
const TArray<int64_t>& elem_nums, \
size_t size, \
int32_t axis, \
int64_t K, \
int64_t largest, \
int64_t sorted, \
int64_t N, \
int64_t dimension)
// This file is causing excessive long compilation time in ROCm EP. Split all those compilation into multiple
// translation units to speed it up.
TOPKIMPLE(TOPK_IMPL_TYPE);

}  // namespace rocm
}  // namespace onnxruntime
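As a quick illustration of the dispatch in TopKImpl above: ALIGN rounds its argument up to the next power of two, and the choice between BitonicTopK, RadixTopK (with KPT of 2/4/8/16), and the full hipcub DeviceRadixSort path depends on the aligned sizes. The standalone host-side sketch below replays that arithmetic; the value 256 for GridDim::maxThreadsPerBlock is an assumption used only for illustration.

// Illustrative host-side sketch (not part of this commit): dimension = 1500, K = 40.
#include <cmath>
#include <cstdint>
#include <cstdio>

static int64_t Align(int64_t n) {  // same formula as the ALIGN macro above
  return static_cast<int64_t>(std::pow(2, std::ceil(std::log2(static_cast<double>(n)))));
}

int main() {
  const int64_t max_threads = 256;  // stand-in for GridDim::maxThreadsPerBlock (== BT), assumed value
  int64_t dimension = 1500, K = 40;
  int64_t aligned_dimension = Align(dimension);  // 2048
  int64_t aligned_K = Align(K);                  // 64
  if (aligned_dimension <= max_threads)
    std::printf("BitonicTopK, shared memory for %lld KV pairs\n", (long long)aligned_dimension);
  else if (K <= max_threads * 16)
    std::printf("RadixTopK (K=40 <= 2*BT, so KPT=2), aligned_K=%lld\n", (long long)aligned_K);
  else
    std::printf("full DeviceRadixSort path\n");
  return 0;
}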
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl.h
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace rocm {

template <typename T>
Status TopKImpl(const RocmKernel* kernel, hipStream_t stream, const T* input_x, T* output_v, int64_t* output_i,
                const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest,
                int64_t sorted, int64_t N, int64_t dimension);

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f16.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE MLFloat16
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f32.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE float
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_f64.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE double
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i16.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int16_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i32.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int32_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i64.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int64_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_i8.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int8_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u16.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint16_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u32.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint32_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u64.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint64_t
#include "topk_impl.cuh"
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/topk_impl_u8.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint8_t
#include "topk_impl.cuh"
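The twelve tiny .cu files above exist only to compile topk_impl.cuh once per element type, which is what the comment at the end of topk_impl.cuh about splitting the file into multiple translation units refers to. Purely as a hypothetical illustration (not part of this commit, and it would also require TopKImpl and NumericLimits support for the type), one more element type would follow the same pattern:

// Hypothetical topk_impl_bf16.cu - illustration of the per-type translation-unit pattern only.
#define TOPK_IMPL_TYPE BFloat16
#include "topk_impl.cuh"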
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops.cc
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "unary_elementwise_ops.h"
#include "unary_elementwise_ops_impl.h"
namespace onnxruntime {
namespace rocm {

Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePreparation* p) const {
  p->input_tensor = context->Input<Tensor>(0);
  p->output_tensor = context->Output(0, p->input_tensor->Shape());
  return Status::OK();
}
#define UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define UNARY_ELEMENTWISE_REGISTER_KERNEL(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
// 'Not' only has a 'T' type constraint. The other logical ops have T and T1.
#define UNARY_ELEMENTWISE_LOGICALOP_NOT_REGISTER_KERNEL_TYPED(ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Not, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Not<T>);
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Impl_##x( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
#define UNARY_OP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
#define UNARY_LOGICALOP_NOT_TYPED(ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_NOT_REGISTER_KERNEL_TYPED(ver, T) \
UNARY_ELEMENTWISE_COMPUTE(Not, T)
// the postfix of means the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define UNARY_OP_VERSIONED_HFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, double)
#define UNARY_OP_VERSIONED_CSILHFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int8_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int16_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
UNARY_OP_VERSIONED_HFD(name, startver, endver)
#define UNARY_OP_VERSIONED_BWUZCSILHFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint8_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint16_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
UNARY_OP_VERSIONED_CSILHFD(name, startver, endver)
#define UNARY_OP_HFD(name, ver) \
UNARY_OP_TYPED(name, ver, MLFloat16) \
UNARY_OP_TYPED(name, ver, float) \
UNARY_OP_TYPED(name, ver, double)
#define UNARY_OP_CSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, int8_t) \
UNARY_OP_TYPED(name, ver, int16_t) \
UNARY_OP_TYPED(name, ver, int32_t) \
UNARY_OP_TYPED(name, ver, int64_t) \
UNARY_OP_HFD(name, ver)
#define UNARY_OP_BWUZCSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, uint8_t) \
UNARY_OP_TYPED(name, ver, uint16_t) \
UNARY_OP_TYPED(name, ver, uint32_t) \
UNARY_OP_TYPED(name, ver, uint64_t) \
UNARY_OP_CSILHFD(name, ver)
UNARY_OP_VERSIONED_BWUZCSILHFD(Abs, 6, 12)
UNARY_OP_VERSIONED_CSILHFD(Neg, 6, 12)
UNARY_OP_VERSIONED_HFD(Floor, 6, 12)
UNARY_OP_VERSIONED_HFD(Ceil, 6, 12)
UNARY_OP_VERSIONED_HFD(Reciprocal, 6, 12)
UNARY_OP_VERSIONED_HFD(Sqrt, 6, 12)
UNARY_OP_VERSIONED_HFD(Log, 6, 12)
UNARY_OP_VERSIONED_HFD(Exp, 6, 12)
UNARY_OP_VERSIONED_HFD(Erf, 9, 12)
UNARY_OP_BWUZCSILHFD(Abs, 13)
UNARY_OP_CSILHFD(Neg, 13)
UNARY_OP_HFD(Floor, 13)
UNARY_OP_HFD(Ceil, 13)
UNARY_OP_HFD(Reciprocal, 13)
UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)
UNARY_OP_HFD(Cos, 7)
UNARY_OP_HFD(Sin, 7)

}  // namespace rocm
}  // namespace onnxruntime
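To make the macro machinery in this file concrete, here is roughly what UNARY_ELEMENTWISE_COMPUTE(Sqrt, float) expands to once the preprocessor substitutes the op name and type (whitespace added for readability; this is a reading aid, not code added by the commit):

// Approximate expansion of UNARY_ELEMENTWISE_COMPUTE(Sqrt, float):
template <>
Status Sqrt<float>::ComputeInternal(OpKernelContext* context) const {
  UnaryElementwisePreparation p;
  ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
  Impl_Sqrt(
      Stream(),
      reinterpret_cast<const typename ToHipType<float>::MappedType*>(p.input_tensor->Data<float>()),
      reinterpret_cast<typename ToHipType<float>::MappedType*>(p.output_tensor->MutableData<float>()),
      p.output_tensor->Shape().Size());
  return Status::OK();
}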
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops.h
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {

struct UnaryElementwisePreparation {
  const Tensor* input_tensor = nullptr;
  Tensor* output_tensor = nullptr;
};

class UnaryElementwise : public RocmKernel {
 protected:
  UnaryElementwise(const OpKernelInfo& info) : RocmKernel(info) {}
  Status ComputeInternal(OpKernelContext*) const override {
    return Status(common::ONNXRUNTIME, common::FAIL);  // should not reach here
  }
  Status Prepare(OpKernelContext* context, UnaryElementwisePreparation* p) const;
};

template <typename T>
class Abs final : public UnaryElementwise {
 public:
  Abs(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Neg final : public UnaryElementwise {
 public:
  Neg(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Floor final : public UnaryElementwise {
 public:
  Floor(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Ceil final : public UnaryElementwise {
 public:
  Ceil(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Reciprocal final : public UnaryElementwise {
 public:
  Reciprocal(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Sqrt final : public UnaryElementwise {
 public:
  Sqrt(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Log final : public UnaryElementwise {
 public:
  Log(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Exp final : public UnaryElementwise {
 public:
  Exp(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Erf final : public UnaryElementwise {
 public:
  Erf(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Not final : public UnaryElementwise {
 public:
  Not(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Round final : public UnaryElementwise {
 public:
  Round(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Sin final : public UnaryElementwise {
 public:
  Sin(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Cos final : public UnaryElementwise {
 public:
  Cos(const OpKernelInfo& info) : UnaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops_impl.cu
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "unary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
namespace onnxruntime {
namespace rocm {
#define OP(name, expr) \
template <typename T> \
struct OP_##name { \
__device__ __inline__ T operator()(const T& a) const { \
return expr; \
} \
};
#define UNARY_ELEMENTWISE_IMPL(name) \
UNARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
OP_##name<T>(), \
count); \
}
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, size_t count);
#define UNARY_OP_NAME_EXPR(name, expr) \
OP(name, expr) \
UNARY_ELEMENTWISE_IMPL(name)
UNARY_OPS()

#undef UNARY_OP_NAME_EXPR
// the postfix of means the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int8_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint16_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint32_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint64_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(name)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Abs)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
// When casting, half needs to be converted via float type from most other types
template <typename T>
struct ViaTypeMap {
  typedef T ViaT;
};

template <>
struct ViaTypeMap<half> {
  typedef float ViaT;
};

template <>
struct ViaTypeMap<BFloat16> {
  typedef float ViaT;
};

template <typename InT, typename OutT>
struct OP_Cast {
  __device__ __inline__ OutT operator()(const InT& a) const {
    const bool any_float16 = std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
    const bool any_bf16 = std::is_same<BFloat16, InT>::value || std::is_same<BFloat16, OutT>::value;
    typedef typename std::conditional<any_bf16, BFloat16, OutT>::type T1;
    typedef typename std::conditional<any_float16, half, T1>::type T;
    typedef typename ViaTypeMap<T>::ViaT ViaT;
    return (OutT)((ViaT)a);
  }
};

template <typename InT, typename OutT>
void Impl_Cast(hipStream_t stream, const InT* input_data, OutT* output_data, size_t count) {
  UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count);
}
#define SPECIALIZED_CAST_IMPL2(InT, OutT) \
template void Impl_Cast<InT, OutT>(hipStream_t stream, const InT* input_data, OutT* output_data, size_t count);
#define SPECIALIZED_CAST_FROM(T) \
SPECIALIZED_CAST_IMPL2(T, half) \
SPECIALIZED_CAST_IMPL2(T, float) \
SPECIALIZED_CAST_IMPL2(T, double) \
SPECIALIZED_CAST_IMPL2(T, int8_t) \
SPECIALIZED_CAST_IMPL2(T, int16_t) \
SPECIALIZED_CAST_IMPL2(T, int32_t) \
SPECIALIZED_CAST_IMPL2(T, int64_t) \
SPECIALIZED_CAST_IMPL2(T, uint8_t) \
SPECIALIZED_CAST_IMPL2(T, uint16_t) \
SPECIALIZED_CAST_IMPL2(T, uint32_t) \
SPECIALIZED_CAST_IMPL2(T, uint64_t) \
SPECIALIZED_CAST_IMPL2(T, bool) \
SPECIALIZED_CAST_IMPL2(T, BFloat16)
SPECIALIZED_CAST_FROM(half)
SPECIALIZED_CAST_FROM(float)
SPECIALIZED_CAST_FROM(double)
SPECIALIZED_CAST_FROM(int8_t)
SPECIALIZED_CAST_FROM(int16_t)
SPECIALIZED_CAST_FROM(int32_t)
SPECIALIZED_CAST_FROM(int64_t)
SPECIALIZED_CAST_FROM(uint8_t)
SPECIALIZED_CAST_FROM(uint16_t)
SPECIALIZED_CAST_FROM(uint32_t)
SPECIALIZED_CAST_FROM(uint64_t)
SPECIALIZED_CAST_FROM(bool)
SPECIALIZED_CAST_FROM(BFloat16)

}  // namespace rocm
}  // namespace onnxruntime
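For reference, a single entry of the OP(...) generator in this file, OP(Neg, -a), expands to roughly the following functor, which UnaryElementWiseImpl then applies element-wise on the given stream (reading aid only, not code added by the commit):

// Approximate preprocessor expansion of OP(Neg, -a):
template <typename T>
struct OP_Neg {
  __device__ __inline__ T operator()(const T& a) const {
    return -a;
  }
};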
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/unary_elementwise_ops_impl.h
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace rocm {
// This macro simplifies coding to add a new op with following steps:
// 1. Add a new entry in UNARY_OPS() list
// 2. (optional) Define templated single element operator in unary_elementwise_ops_impl.cu
// 3. (optional) Implement specialized single element operator
// 4. Add op kernel class definition in unary_elementwise_ops.h
// 5. Add op kernel registration and compute specialization in unary_elementwise_ops.cc
#define UNARY_OPS() \
UNARY_OP_NAME_EXPR(Abs, _Abs(a)) \
UNARY_OP_NAME_EXPR(Neg, -a) \
UNARY_OP_NAME_EXPR(Ceil, _Ceil(a)) \
UNARY_OP_NAME_EXPR(Floor, _Floor(a)) \
UNARY_OP_NAME_EXPR(Reciprocal, T(1) / a) \
UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
UNARY_OP_NAME_EXPR(Cos, _Cos(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
hipStream_t stream, \
const T* input_data, \
T* output_data, \
size_t count)
#define UNARY_OP_NAME_EXPR(name, expr) UNARY_ELEMENTWISE_IMPL_DECLARATION(name);
UNARY_OPS()

#undef UNARY_OP_NAME_EXPR

template <typename InT, typename OutT>
void Impl_Cast(hipStream_t stream, const InT* input_data, OutT* output_data, size_t count);

}  // namespace rocm
}  // namespace onnxruntime
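The comment at the top of this header lists the steps for adding a new unary op. As a hypothetical illustration of step 1 (the op name and the _Tanh device helper are assumptions for the example, not something this commit adds), a Tanh entry would be one more line in the UNARY_OPS() list, and the declaration macro would then emit the matching Impl_ prototype:

// Hypothetical step 1: extend UNARY_OPS() with
//   UNARY_OP_NAME_EXPR(Tanh, _Tanh(a))
// UNARY_ELEMENTWISE_IMPL_DECLARATION(Tanh) would then declare:
template <typename T>
void Impl_Tanh(hipStream_t stream, const T* input_data, T* output_data, size_t count);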
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/variadic_elementwise_ops.cc
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/math/variadic_elementwise_ops.h"
#include <cassert>
#include <algorithm>
#include "core/framework/data_types_internal.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/variadic_elementwise_ops_impl.h"
#include "core/providers/rocm/math/variadic_elementwise_ops_tags.h"
namespace onnxruntime {
namespace rocm {

template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
    NoBroadcastBatchImplDispatchTarget<T>::operator()(hipStream_t stream, const InputTensorVector& inputs,
                                                      Tensor& output) const {
  using HipT = typename ToHipType<T>::MappedType;
  size_t input_count = inputs.size();
  assert(input_count > 1);
  size_t index = std::min(input_count, static_cast<size_t>(k_max_input_batch_size));
  InputBatchArray<HipT> input_data_batch{static_cast<int32_t>(index)};
  for (size_t i = 0; i < index; ++i) {
    input_data_batch[static_cast<int32_t>(i)] = reinterpret_cast<const HipT*>(inputs[i].get().Data<T>());
  }
  HipT* output_data = reinterpret_cast<HipT*>(output.MutableData<T>());
  Impl_NoBroadcastInputBatch<HipT, VariadicElementwiseOpTag>(stream, input_data_batch, output_data,
                                                             output.Shape().Size());
  while (index < input_count) {
    size_t left_count = input_count - index + 1;
    size_t batch = std::min(left_count, static_cast<size_t>(k_max_input_batch_size));
    // Special case for 2 inputs left.
    if (batch == 2) {
      BinaryElementwisePreparation prepare;
      ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[input_count - 1].get(), &output, &prepare));
      Impl_General<HipT, VariadicElementwiseOpTag>(
          stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
          reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
          reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
          prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      // Must be the last.
      break;
    }
    InputBatchArray<HipT> left_input_data_batch{static_cast<int32_t>(batch)};
    left_input_data_batch[0] = reinterpret_cast<const HipT*>(output.Data<T>());
    for (size_t i = 1; i < batch; ++i) {
      left_input_data_batch[static_cast<int32_t>(i)] = reinterpret_cast<const HipT*>(inputs[index].get().Data<T>());
      index++;
    }
    Impl_NoBroadcastInputBatch<HipT, VariadicElementwiseOpTag>(stream, left_input_data_batch, output_data,
                                                               output.Shape().Size());
  }
  return Status::OK();
}

// special case for 2 tensors to avoid memset zero
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
    BinaryImplDispatchTarget<T>::operator()(hipStream_t stream, const Tensor& lhs, const Tensor& rhs,
                                            Tensor& output) const {
  using HipT = typename ToHipType<T>::MappedType;
  BinaryElementwisePreparation prepare;
  ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&lhs, &rhs, &output, &prepare));
  Impl_General<HipT, VariadicElementwiseOpTag>(
      stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
      reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
      reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
      prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
      prepare.output_tensor->Shape().Size());
  return Status::OK();
}

// for more than 2 inputs, we need to accumulate into output tensor, as the shape from input0 + input1 might be different from output shape
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
    GeneralImplDispatchTarget<T>::operator()(hipStream_t stream, const InputTensorVector& inputs,
                                             Tensor& output) const {
  assert(inputs.size() > 1);
  using HipT = typename ToHipType<T>::MappedType;
  // If there is any input having the same shape with output, we don't need the memset.
  size_t index_of_same_shape = 0;
  for (; index_of_same_shape < inputs.size(); index_of_same_shape++) {
    if (inputs[index_of_same_shape].get().Shape() == output.Shape()) {
      break;
    }
  }
  BinaryElementwisePreparation prepare;
  // No input has same shape of output, memset the output, and add the 1st input as initialization.
  if (index_of_same_shape == inputs.size()) {
    HIP_RETURN_IF_ERROR(hipMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
    ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
    Impl_Add(
        stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
        prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
        prepare.output_tensor->Shape().Size());
  } else {
    // First operation is between input[0] and input[index_of_same_shape] if index_of_same_shape is not 0.
    size_t index = index_of_same_shape == 0 ? 1 : 0;
    ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&inputs[index_of_same_shape].get(), &inputs[index].get(),
                                                          &output, &prepare));
    Impl_General<HipT, VariadicElementwiseOpTag>(
        stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
        prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
        prepare.output_tensor->Shape().Size());
  }
  for (size_t index = 1; index < inputs.size(); index++) {
    // If index_of_same_shape is 0, we already handle the 1st and 2nd inputs.
    if (index == index_of_same_shape || (index_of_same_shape == 0 && index == 1)) {
      continue;
    }
    ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
    Impl_General<HipT, VariadicElementwiseOpTag>(
        stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
        reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
        prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
        prepare.output_tensor->Shape().Size());
  }
  return Status::OK();
}

template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::ComputeInternal(
    OpKernelContext* context) const {
  const auto& node = Node();
  const auto& node_name = node.Name();
  auto input_count = node.InputArgCount().front();
  ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
  const InputTensorVector input_tensors = [&context, input_count]() {
    InputTensorVector result{};
    result.reserve(input_count);
    for (int i = 0; i < input_count; ++i) {
      const auto& tensor = context->RequiredInput<Tensor>(i);
      result.push_back(std::cref(tensor));
    }
    return result;
  }();
  const auto& first_input_tensor = input_tensors[0].get();
  // special case for 1 input
  if (input_count == 1) {
    auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
    if (first_input_tensor.DataRaw() != output_tensor.DataRaw()) {
      HIP_RETURN_IF_ERROR(hipMemcpyAsync(output_tensor.MutableDataRaw(), first_input_tensor.DataRaw(),
                                         first_input_tensor.SizeInBytes(), hipMemcpyDeviceToDevice, Stream()));
    }
    return Status::OK();
  }
  const auto element_type = first_input_tensor.GetElementType();
  utils::MLTypeCallDispatcher<SupportedElementTypes...> dispatcher(element_type);
  // Special case for no broadcasting.
  if (std::all_of(input_tensors.begin() + 1, input_tensors.end(),
                  [&first_input_tensor](InputTensorVector::value_type t) {
                    return first_input_tensor.Shape() == t.get().Shape();
                  })) {
    auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
    // special case for no broadcasting and 2 inputs
    if (input_count == 2) {
      return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(Stream(), input_tensors[0],
                                                                             input_tensors[1], output_tensor);
    }
    return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(Stream(), input_tensors,
                                                                                     output_tensor);
  }
  // compute output shape first, using broadcast rule
  TensorShape output_shape;
  TensorShape previous_output_shape = first_input_tensor.Shape();
  for (int index = 1; index < input_count; index++) {
    ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, previous_output_shape, input_tensors[index].get().Shape(),
                                           output_shape));
    previous_output_shape = output_shape;
  }
  Tensor& output_tensor = context->RequiredOutput(0, output_shape);
  // special case for 2 inputs
  if (input_count == 2) {
    return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(Stream(), input_tensors[0],
                                                                           input_tensors[1], output_tensor);
  }
  // general case for more than 2 inputs
  return dispatcher.template InvokeRet<Status, GeneralImplDispatchTarget>(Stream(), input_tensors, output_tensor);
}

namespace {

using SumOp = VariadicElementwiseOp<variadic_elementwise_ops::Sum, MLFloat16, float, double, BFloat16>;
using MinOp = VariadicElementwiseOp<variadic_elementwise_ops::Min, uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16>;
using MaxOp = VariadicElementwiseOp<variadic_elementwise_ops::Max, uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16>;

}  // namespace
// kernel registration
#define REGISTER_KERNEL(name, impl_class, version, datatypes) \
ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, version, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<datatypes>()), \
impl_class)
#define REGISTER_VERSIONED_KERNEL(name, impl_class, start_version, end_version, datatypes) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
name, kOnnxDomain, start_version, end_version, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<datatypes>()), impl_class)
#define UZILHFD_TYPES uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16
#define HFD_TYPES MLFloat16, float, double, BFloat16
REGISTER_KERNEL(Sum, SumOp, 13, HFD_TYPES)
REGISTER_VERSIONED_KERNEL(Sum, SumOp, 8, 12, HFD_TYPES)
REGISTER_VERSIONED_KERNEL(Sum, SumOp, 6, 7, HFD_TYPES)
REGISTER_KERNEL(Min, MinOp, 13, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Min, MinOp, 12, 12, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Min, MinOp, 6, 11, HFD_TYPES)
REGISTER_KERNEL(Max, MaxOp, 13, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Max, MaxOp, 12, 12, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Max, MaxOp, 6, 11, HFD_TYPES)

#undef HFD_TYPES
#undef UZILHFD_TYPES
#undef REGISTER_VERSIONED_KERNEL
#undef REGISTER_KERNEL

}  // namespace rocm
}  // namespace onnxruntime
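One detail of NoBroadcastBatchImplDispatchTarget above that is easy to misread is the `input_count - index + 1` term: after the first batch of fresh inputs, the running output tensor occupies slot 0 of every subsequent batch, so each later pass consumes at most k_max_input_batch_size - 1 new inputs. The standalone host-side sketch below only replays that arithmetic; the value 8 for k_max_input_batch_size is an assumption for illustration, and the batch == 2 special case (plain binary path, then stop) is noted but not simulated.

// Illustrative host-side sketch (not part of this commit): batching of 20 same-shape inputs.
#include <algorithm>
#include <cstdio>

int main() {
  const size_t k_max_input_batch_size = 8;  // assumed value for illustration
  size_t input_count = 20;
  size_t index = std::min(input_count, k_max_input_batch_size);
  std::printf("first batch: inputs[0..%zu]\n", index - 1);
  while (index < input_count) {
    size_t left_count = input_count - index + 1;  // +1 accounts for the running output tensor
    size_t batch = std::min(left_count, k_max_input_batch_size);
    // (the real kernel special-cases batch == 2 with the plain binary path and then breaks)
    std::printf("next batch: output + inputs[%zu..%zu]\n", index, index + batch - 2);
    index += batch - 1;  // batch - 1 new inputs consumed per pass
  }
  return 0;
}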
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/variadic_elementwise_ops.h
0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include <vector>
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {

using InputTensorVector = std::vector<std::reference_wrapper<const Tensor>>;

template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
class VariadicElementwiseOp : public RocmKernel {
 public:
  VariadicElementwiseOp(const OpKernelInfo& info) : RocmKernel(info) {}

 private:
  Status ComputeInternal(OpKernelContext* context) const override;

  template <typename T>
  struct NoBroadcastBatchImplDispatchTarget {
    Status operator()(hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const;
  };

  template <typename T>
  struct BinaryImplDispatchTarget {
    Status operator()(hipStream_t stream, const Tensor& lhs, const Tensor& rhs, Tensor& output) const;
  };

  template <typename T>
  struct GeneralImplDispatchTarget {
    Status operator()(hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const;
  };
};

}  // namespace rocm
}  // namespace onnxruntime