Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
45da3115
Commit
45da3115
authored
Mar 07, 2022
by
Shucai Xiao
Browse files
backup code changes related to softmax
parent
ea656c84
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
166 additions
and
22 deletions
+166
-22
src/targets/gpu/device/gelu.cpp
src/targets/gpu/device/gelu.cpp
+2
-2
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+6
-4
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+158
-16
No files found.
src/targets/gpu/device/gelu.cpp
View file @
45da3115
...
@@ -59,8 +59,8 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
...
@@ -59,8 +59,8 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2
sqrt2
=
__float2half2_rn
(
M_SQRT1_2
);
__half2
sqrt2
=
__float2half2_rn
(
M_SQRT1_2
);
auto
x
=
__hmul2
(
sum
,
sqrt2
);
auto
x
=
__hmul2
(
sum
,
sqrt2
);
auto
f2
=
__half22float2
(
x
);
auto
f2
=
__half22float2
(
x
);
f2
.
x
=
::
erf
(
f2
.
x
);
f2
.
x
=
::
erf
f
(
f2
.
x
);
f2
.
y
=
::
erf
(
f2
.
y
);
f2
.
y
=
::
erf
f
(
f2
.
y
);
auto
h2
=
__floats2half2_rn
(
f2
.
x
,
f2
.
y
);
auto
h2
=
__floats2half2_rn
(
f2
.
x
,
f2
.
y
);
auto
one
=
__float2half2_rn
(
1.0
f
);
auto
one
=
__float2half2_rn
(
1.0
f
);
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
45da3115
...
@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
...
@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
if(enabled(MIGRAPHX_TRACE_NARY{})) \
if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
static
index_int
group_num_global
=
(
1
<<
20
);
template
<
class
...
Ts
>
template
<
class
...
Ts
>
constexpr
auto
pack
(
Ts
...
xs
)
constexpr
auto
pack
(
Ts
...
xs
)
{
{
...
@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
...
@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
const
index_int
vec_size
=
4
;
const
index_int
vec_size
=
4
;
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
group_num_global
*
nlocal
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)(
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)(
[
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
...
@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
group_num_global
*
nlocal
;
index_int
nelements
=
result
.
get_shape
().
elements
();
index_int
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
using
type
=
typename
decltype
(
output
)
::
value_type
;
...
@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
...
@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
const
index_int
vec_size
=
4
;
const
index_int
vec_size
=
4
;
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
group_num_global
*
nlocal
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg1
,
barg2
,
args
...)(
hip_vec_visit_all
<
vec_size
>
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
...
@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
...
@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
group_num_global
*
nlocal
;
index_int
nelements
=
result
.
get_shape
().
elements
();
index_int
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
...
...
src/targets/gpu/device/softmax.cpp
View file @
45da3115
...
@@ -6,12 +6,144 @@
...
@@ -6,12 +6,144 @@
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
// Binary functor: lane-wise addition of two packed __half2 values.
// Passed to block_reduce to accumulate the softmax denominator.
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        return __hadd2(x, y);
    }
};
// Lane-wise maximum of two __half2 values.
// The comparison is done in float precision (after widening each lane)
// and the winners are rounded back into a packed __half2.
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    const auto a = __half22float2(x);
    const auto b = __half22float2(y);
    const auto lo = (a.x > b.x) ? a.x : b.x;
    const auto hi = (a.y > b.y) ? a.y : b.y;
    return __floats2half2_rn(lo, hi);
}
// Binary functor: lane-wise maximum of two packed __half2 values.
// Passed to block_reduce to find the per-batch maximum for a
// numerically stable softmax.
struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        return hmax2(x, y);
    }
};
// Tree-reduce `buffer` (shared memory, `batch_item_num` __half2 elements)
// across the threads of one block with the binary op `op`.
// `tid` is the thread's index within the block. After the tree pass the
// block-wide partial result sits in buffer[0]; its two half lanes are then
// combined with `op` so the returned __half2 carries the scalar result in
// both lanes.
// NOTE(review): the tree depth is bounded by block_size, so this assumes
// batch_item_num does not exceed what log2(block_size) passes can cover —
// confirm at call sites.
template <class Op>
__device__ __half2 block_reduce(
    __half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Callers fill `buffer` immediately before calling; without a barrier
    // the first tree step below would read elements written by *other*
    // threads that may not yet be visible (data race).
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            buffer[index] = op(buffer[index], buffer[index + s]);
        }
        __syncthreads();
    }
    // buffer[0] still holds two independent lanes; fold them together so
    // every thread returns the same scalar value broadcast to both lanes.
    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);
    return op(lows2, highs2);
}
// Softmax along the last (contiguous) axis for half data, processing two
// values at a time through __half2. One block handles one batch (row).
// Dynamic shared memory must hold 2 * batch_item_num __half2 elements:
// a reduction scratch region followed by a staged copy of the row.
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    // Each __half2 packs two values, so halve the per-row element count.
    batch_item_num /= 2;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    // Offset of this block's row in global memory
    // (tid / block_size is the block index).
    int start = tid / block_size * batch_item_num;
    // Stage the row into shared memory. Loop with threadIdx.x (the
    // thread's index inside its block), not the global tid: with a global
    // index every block past the first would start at i >= block_size and
    // never cover [0, batch_item_num), leaving those rows unprocessed
    // (the sibling softmax_kernel2 already loops with threadIdx.x).
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }
    __syncthreads(); // staged data must be visible before the cross-thread reduce
    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
    // All threads read buffer element 0 inside block_reduce; wait for them
    // before the loop below overwrites the reduction scratch.
    __syncthreads();
    // exp(x - max) for numerical stability; keep a copy for the sum reduce.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }
    __syncthreads(); // exp results must be visible before the sum reduce
    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}
// Tree-reduce `data` (shared memory, `batch_item_num` __half elements)
// across the threads of one block with the binary op `op`.
// `tid` is the thread's index within the block; the block-wide result
// ends up in data[0] and is returned to every thread.
// NOTE(review): the tree depth is bounded by block_size, so this assumes
// batch_item_num does not exceed what log2(block_size) passes can cover —
// confirm at call sites.
template <class Op>
__device__ __half block_reduce2(
    __half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Callers fill `data` immediately before calling; without a barrier
    // the first tree step below would read elements written by *other*
    // threads that may not yet be visible (data race).
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            data[index] = op(data[index], data[index + s]);
        }
        __syncthreads();
    }
    return data[0];
}
// Softmax along the last (contiguous) axis for half data, one __half per
// element (scalar variant of softmax_kernel). One block handles one batch
// (row). Dynamic shared memory must hold 2 * batch_item_num __half
// elements: a reduction scratch region followed by a staged copy of the row.
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    int tid        = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];
    __half* in_data_reduce = buffer;
    __half* in_data        = buffer + batch_item_num;
    // Offset of this block's row in global memory
    // (tid / block_size is the block index).
    int start = tid / block_size * batch_item_num;
    // Stage the row into shared memory.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }
    __syncthreads(); // staged data must be visible before the cross-thread reduce
    auto batch_max =
        block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    // All threads read data[0] inside block_reduce2; wait for them before
    // the loop below overwrites the reduction scratch.
    __syncthreads();
    // exp(x - max) in float for numerical stability; keep a copy for the
    // sum reduce.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
    }
    __syncthreads(); // exp results must be visible before the sum reduce
    auto batch_sum =
        block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
{
{
auto
batch_lens
=
result
.
get_shape
().
lens
();
auto
batch_lens
=
result
.
get_shape
().
lens
();
...
@@ -27,25 +159,35 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -27,25 +159,35 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if
(
axis
==
batch_lens
.
size
()
-
1
)
if
(
axis
==
batch_lens
.
size
()
-
1
)
{
{
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
auto
in_type
=
result
.
get_shape
().
type
();
[
=
](
auto
i
,
auto
idx
)
__device__
{
if
(
in_type
==
shape
::
half_type
and
batch_item_num
<=
2048
)
auto
start_loc
=
i
/
block_size
*
batch_item_num
;
{
auto
batch_max
=
block_reduce
<
max_block_size
>
(
int
block_num
=
batch_shape
.
elements
();
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
int
shared_size
=
batch_item_num
*
2
*
result
.
get_shape
().
type_size
();
return
input
[
start_loc
+
j
];
softmax_kernel2
<<<
block_num
,
block_size
,
shared_size
,
stream
>>>
(
arg
.
data
(),
batch_item_num
,
block_size
,
result
.
data
());
});
}
else
{
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
start_loc
=
i
/
block_size
*
batch_item_num
;
auto
batch_max
=
block_reduce
<
max_block_size
>
(
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
return
input
[
start_loc
+
j
];
});
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
return
::
exp
(
to_hip_type
(
val
));
return
::
exp
(
to_hip_type
(
val
));
});
});
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
__device__
{
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
output
[
start_loc
+
j
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
output
[
start_loc
+
j
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
});
});
});
});
}
}
}
else
else
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment