Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bc9eac75
Commit
bc9eac75
authored
Mar 08, 2022
by
Shucai Xiao
Browse files
version that softmax half2 works
parent
23a18b2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
4 deletions
+6
-4
src/targets/gpu/device/mul.cpp
src/targets/gpu/device/mul.cpp
+1
-1
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+5
-3
No files found.
src/targets/gpu/device/mul.cpp
View file @
bc9eac75
...
@@ -50,7 +50,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
...
@@ -50,7 +50,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
}
}
else
else
{
{
nary
(
stream
,
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
__device__
{
return
x
+
y
;
});
nary
(
stream
,
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
__device__
{
return
x
*
y
;
});
}
}
}
}
...
...
src/targets/gpu/device/softmax.cpp
View file @
bc9eac75
...
@@ -114,28 +114,30 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
...
@@ -114,28 +114,30 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
{
{
__half
*
input
=
reinterpret_cast
<
__half
*>
(
data_in
);
__half
*
input
=
reinterpret_cast
<
__half
*>
(
data_in
);
__half
*
output
=
reinterpret_cast
<
__half
*>
(
data_out
);
__half
*
output
=
reinterpret_cast
<
__half
*>
(
data_out
);
int
tid
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
extern
MIGRAPHX_DEVICE_SHARED
__half
buffer
[];
extern
MIGRAPHX_DEVICE_SHARED
__half
buffer
[];
__half
*
in_data_reduce
=
buffer
;
__half
*
in_data_reduce
=
buffer
;
__half
*
in_data
=
buffer
+
batch_item_num
;
__half
*
in_data
=
buffer
+
batch_item_num
;
int
start
=
tid
/
block_size
*
batch_item_num
;
int
start
=
blockIdx
.
x
*
batch_item_num
;
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
{
auto
d
=
input
[
i
+
start
];
auto
d
=
input
[
i
+
start
];
in_data
[
i
]
=
d
;
in_data
[
i
]
=
d
;
in_data_reduce
[
i
]
=
d
;
in_data_reduce
[
i
]
=
d
;
// printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
}
}
auto
batch_max
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
max
{});
auto
batch_max
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
max
{});
// printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
{
in_data
[
i
]
=
__float2half
(
::
exp
(
__half2float
(
in_data
[
i
])
-
__half2float
(
batch_max
)));
in_data
[
i
]
=
__float2half
(
::
exp
(
__half2float
(
in_data
[
i
])
-
__half2float
(
batch_max
)));
in_data_reduce
[
i
]
=
in_data
[
i
];
in_data_reduce
[
i
]
=
in_data
[
i
];
// printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
}
}
auto
batch_sum
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
sum
{});
auto
batch_sum
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
sum
{});
// printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment