Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
1cadb2a1
Commit
1cadb2a1
authored
Sep 03, 2025
by
xgqdut2016
Browse files
issue/342: delete to
parent
0ecbe1d5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
25 deletions
+9
-25
src/infiniop/devices/kunlun/kunlun_kernel_common.h
src/infiniop/devices/kunlun/kunlun_kernel_common.h
+0
-16
src/infiniop/ops/random_sample/kunlun/kernel.h
src/infiniop/ops/random_sample/kunlun/kernel.h
+9
-9
No files found.
src/infiniop/devices/kunlun/kunlun_kernel_common.h
View file @
1cadb2a1
...
...
@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
__builtin_memcpy
(
v
,
p
,
len
*
sizeof
(
T
));
}
/**
* @brief Convert data type. All data is in local memory
* @param v: input value
* @return output value
*/
template
<
typename
Tout
,
typename
Tin
>
__device__
inline
Tout
to
(
Tin
v
)
{
if
constexpr
(
std
::
is_same
<
Tin
,
half
>::
value
)
{
return
__half2float
(
v
);
}
else
if
constexpr
(
std
::
is_same
<
Tin
,
bfloat16_t
>::
value
)
{
return
__bfloat162float
(
v
);
}
else
{
return
static_cast
<
Tout
>
(
v
);
}
}
/**
* @brief atomicAdd for kunlun xpu
* @param ptr: pointer to shared memory
...
...
src/infiniop/ops/random_sample/kunlun/kernel.h
View file @
1cadb2a1
...
...
@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
__shared__
Tcompute
sum_
;
if
(
core_id
()
==
0
)
{
sum_
=
to
<
Tcompute
>
(
0.
f
);
sum_
=
Tcompute
(
0.
f
);
}
sync_cluster
();
...
...
@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
for
(
int
index
=
core_id
();
index
<
read_len
;
index
+=
BLOCK_SIZE
)
{
if
constexpr
(
std
::
is_same_v
<
Tval
,
half
>
)
{
y_sm
[
index
]
=
__float2half
(
exp
((
__half2float
(
x_sm
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
));
y_sm
[
index
]
=
__float2half
(
exp
((
__half2float
(
x_sm
[
index
])
-
float
(
max_value
))
/
temperature
));
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
bfloat16_t
>
)
{
y_sm
[
index
]
=
__float2bfloat16
(
exp
((
__bfloat162float
(
x_sm
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
));
y_sm
[
index
]
=
__float2bfloat16
(
exp
((
__bfloat162float
(
x_sm
[
index
])
-
float
(
max_value
))
/
temperature
));
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
float
>
)
{
y_sm
[
index
]
=
exp
((
x_sm
[
index
]
-
max_value
)
/
temperature
);
}
...
...
@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM
(
values_global
+
r
*
buf_size
,
values_local
,
read_len
*
sizeof
(
Tval
));
for
(
int
index
=
0
;
index
<
read_len
;
index
++
)
{
if
constexpr
(
std
::
is_same_v
<
Tval
,
float
>
)
{
cumsum
+=
exp
((
values_local
[
index
]
-
max_value
)
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
values_local
[
index
]
-
max_value
)
/
temperature
)
/
float
(
all_sum
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
bfloat16_t
>
)
{
cumsum
+=
exp
((
to
<
float
>
(
values_local
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
float
(
values_local
[
index
])
-
float
(
max_value
))
/
temperature
)
/
float
(
all_sum
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
half
>
)
{
cumsum
+=
exp
((
to
<
float
>
(
values_local
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
float
(
values_local
[
index
])
-
float
(
max_value
))
/
temperature
)
/
float
(
all_sum
);
}
if
(
cumsum
>=
topp
)
{
end
=
r
*
buf_size
+
index
+
1
;
...
...
@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM
(
values_global
+
r
*
buf_size
,
values_local
,
read_len
*
sizeof
(
Tval
));
for
(
int
index
=
0
;
index
<
read_len
;
index
++
)
{
if
constexpr
(
std
::
is_same_v
<
Tval
,
float
>
)
{
cumsum
+=
exp
((
values_local
[
index
]
-
max_value
)
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
values_local
[
index
]
-
max_value
)
/
temperature
)
/
float
(
all_sum
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
bfloat16_t
>
)
{
cumsum
+=
exp
((
to
<
float
>
(
values_local
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
float
(
values_local
[
index
])
-
float
(
max_value
))
/
temperature
)
/
float
(
all_sum
);
}
else
if
constexpr
(
std
::
is_same_v
<
Tval
,
half
>
)
{
cumsum
+=
exp
((
to
<
float
>
(
values_local
[
index
])
-
to
<
float
>
(
max_value
))
/
temperature
)
/
to
<
float
>
(
all_sum
);
cumsum
+=
exp
((
float
(
values_local
[
index
])
-
float
(
max_value
))
/
temperature
)
/
float
(
all_sum
);
}
if
(
random_val
<
cumsum
)
{
result
[
0
]
=
indices_global
[
r
*
buf_size
+
index
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment