Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
08a29c28
Commit
08a29c28
authored
Feb 24, 2025
by
xgqdut2016
Browse files
issue/66: modified random sample test function
parent
04aa18f6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
54 deletions
+49
-54
test/infiniop/causal_softmax.py
test/infiniop/causal_softmax.py
+0
-1
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+49
-53
No files found.
test/infiniop/causal_softmax.py
View file @
08a29c28
...
@@ -21,7 +21,6 @@ from libinfiniop import (
...
@@ -21,7 +21,6 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# Configuration (Internal Use Only)
# ==============================================================================
# ==============================================================================
# These are not meant to be imported from other modules
# These are not meant to be imported from other modules
_TEST_CASES
=
[
_TEST_CASES
=
[
# x_shape, x_stride
# x_shape, x_stride
((
32
,
512
),
None
),
((
32
,
512
),
None
),
...
...
test/infiniop/random_sample.py
View file @
08a29c28
...
@@ -22,7 +22,6 @@ from libinfiniop import (
...
@@ -22,7 +22,6 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# Configuration (Internal Use Only)
# ==============================================================================
# ==============================================================================
# These are not meant to be imported from other modules
# These are not meant to be imported from other modules
_TEST_CASES
=
[
_TEST_CASES
=
[
# voc, random_val, topp, topk, temperature
# voc, random_val, topp, topk, temperature
(
512
,
0.8
,
0.8
,
3
,
0.5
),
(
512
,
0.8
,
0.8
,
3
,
0.5
),
...
@@ -59,53 +58,52 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
...
@@ -59,53 +58,52 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def
random_sample
(
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
,
torch_device
):
def
random_sample
(
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
,
torch_device
):
indices
=
torch
.
zeros
([
topk
],
dtype
=
torch
.
int64
)
if
topp
>
0
and
topk
>
1
:
dataNp
=
data
.
clone
().
detach
()
indices
=
torch
.
zeros
([
topk
],
dtype
=
torch
.
int64
)
sorted_indices
=
torch
.
arange
(
voc
)
dataNp
=
data
.
clone
().
detach
()
sorted_indices
=
torch
.
arange
(
voc
)
for
i
in
range
(
topk
):
for
j
in
range
(
i
+
1
,
voc
):
for
i
in
range
(
topk
):
if
dataNp
[
i
]
<
dataNp
[
j
]:
for
j
in
range
(
i
+
1
,
voc
):
tmp
=
dataNp
[
i
].
clone
().
detach
()
if
dataNp
[
i
]
<
dataNp
[
j
]:
dataNp
[
i
]
=
dataNp
[
j
].
clone
().
detach
()
tmp
=
dataNp
[
i
].
clone
().
detach
()
dataNp
[
j
]
=
tmp
dataNp
[
i
]
=
dataNp
[
j
].
clone
().
detach
()
dataNp
[
j
]
=
tmp
tmpInd
=
sorted_indices
[
i
].
clone
().
detach
()
sorted_indices
[
i
]
=
sorted_indices
[
j
].
clone
().
detach
()
tmpInd
=
sorted_indices
[
i
].
clone
().
detach
()
sorted_indices
[
j
]
=
tmpInd
sorted_indices
[
i
]
=
sorted_indices
[
j
].
clone
().
detach
()
sorted_indices
[
j
]
=
tmpInd
# sorted_indices = torch.argsort(dataNp, descending=True)
indices
=
sorted_indices
[:
topk
]
# sorted_indices = torch.argsort(dataNp, descending=True)
indices
=
sorted_indices
[:
topk
]
dataNp
=
dataNp
[
sorted_indices
]
dataNp
=
dataNp
[
sorted_indices
]
globalM
=
dataNp
[
0
]
dataNp
=
(
dataNp
-
globalM
)
/
temperature
globalM
=
dataNp
[
0
]
dataNp
=
torch
.
softmax
(
dataNp
.
float
(),
dim
=
0
)
dataNp
=
(
dataNp
-
globalM
)
/
temperature
sum_s
=
0
dataNp
=
torch
.
softmax
(
dataNp
.
float
(),
dim
=
0
)
for
end
in
range
(
topk
):
sum_s
=
0
sum_s
+=
dataNp
[
end
]
for
end
in
range
(
topk
):
if
sum_s
>=
topp
:
sum_s
+=
dataNp
[
end
]
break
if
sum_s
>=
topp
:
if
end
<
topk
-
1
:
break
end
+=
1
if
end
<
topk
-
1
:
end
+=
1
else
:
end
=
topk
sum_s
=
0
for
i
in
range
(
end
):
sum_s
+=
dataNp
[
i
]
random_val
*=
sum_s
sum_s
=
0
for
i
in
range
(
end
):
sum_s
+=
dataNp
[
i
]
if
random_val
<
sum_s
:
return
indices
[
i
]
else
:
else
:
end
=
topk
return
torch
.
argmax
(
data
)
sum_s
=
0
for
i
in
range
(
end
):
sum_s
+=
dataNp
[
i
]
random_val
*=
sum_s
sum_s
=
0
for
i
in
range
(
end
):
sum_s
+=
dataNp
[
i
]
if
random_val
<
sum_s
:
return
indices
[
i
]
def
random_sample_0
(
data
):
return
torch
.
argmax
(
data
)
def
test
(
def
test
(
...
@@ -124,12 +122,10 @@ def test(
...
@@ -124,12 +122,10 @@ def test(
data
=
torch
.
arange
(
voc
).
float
()
*
0.0001
data
=
torch
.
arange
(
voc
).
float
()
*
0.0001
_perm
=
torch
.
randperm
(
voc
)
_perm
=
torch
.
randperm
(
voc
)
data
=
data
[
_perm
].
to
(
x_dtype
).
to
(
torch_device
)
data
=
data
[
_perm
].
to
(
x_dtype
).
to
(
torch_device
)
if
topp
>
0
and
topk
>
1
:
ans
=
random_sample
(
ans
=
random_sample
(
data
.
to
(
"cpu"
),
random_val
,
topp
,
topk
,
voc
,
temperature
,
"cpu"
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
,
torch_device
)
)
# 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
else
:
ans
=
random_sample_0
(
data
)
indices
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int64
).
to
(
torch_device
)
indices
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int64
).
to
(
torch_device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment