Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0883d6ee
Unverified
Commit
0883d6ee
authored
Nov 17, 2025
by
PanZezhong1725
Committed by
GitHub
Nov 17, 2025
Browse files
Merge pull request #605 from InfiniTensor/issue/603
Issue/603 - 优化张量复制逻辑
parents
17f65139
cf4403d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
62 deletions
+74
-62
src/infinicore/ops/random_sample/random_sample_infiniop.cc
src/infinicore/ops/random_sample/random_sample_infiniop.cc
+1
-1
test/infinicore/framework/base.py
test/infinicore/framework/base.py
+73
-61
No files found.
src/infinicore/ops/random_sample/random_sample_infiniop.cc
View file @
0883d6ee
...
@@ -35,7 +35,7 @@ static void calculate(
...
@@ -35,7 +35,7 @@ static void calculate(
if
(
!
desc_opt
)
{
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateRandomSampleDescriptor
(
INFINICORE_CHECK_ERROR
(
infiniopCreateRandomSampleDescriptor
(
context
::
getInfiniopHandle
(),
&
desc
,
context
::
getInfiniopHandle
(
indices
->
device
()
),
&
desc
,
indices
->
desc
(),
logits
->
desc
()));
indices
->
desc
(),
logits
->
desc
()));
cache
.
put
(
seed
,
desc
);
cache
.
put
(
seed
,
desc
);
}
else
{
}
else
{
...
...
test/infinicore/framework/base.py
View file @
0883d6ee
...
@@ -260,9 +260,7 @@ class TestRunner:
...
@@ -260,9 +260,7 @@ class TestRunner:
return
False
return
False
except
Exception
as
e
:
except
Exception
as
e
:
error_msg
=
(
error_msg
=
f
"Error:
{
e
}
"
f
"
{
test_case
}
-
{
InfiniDeviceNames
[
device
]
}
- Error:
{
e
}
"
)
print
(
f
"
\033
[91m✗
\033
[0m
{
error_msg
}
"
)
print
(
f
"
\033
[91m✗
\033
[0m
{
error_msg
}
"
)
self
.
failed_tests
.
append
(
error_msg
)
self
.
failed_tests
.
append
(
error_msg
)
...
@@ -392,7 +390,7 @@ class BaseOperatorTest(ABC):
...
@@ -392,7 +390,7 @@ class BaseOperatorTest(ABC):
return
spec
.
create_torch_tensor
(
device
)
return
spec
.
create_torch_tensor
(
device
)
return
spec
return
spec
def
prepare_inputs_and_kwargs
(
self
,
test_case
,
device
):
def
prepare_
pytorch_
inputs_and_kwargs
(
self
,
test_case
,
device
):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
"""
"""
...
@@ -455,6 +453,71 @@ class BaseOperatorTest(ABC):
...
@@ -455,6 +453,71 @@ class BaseOperatorTest(ABC):
return
inputs
,
kwargs
return
inputs
,
kwargs
def
prepare_infinicore_list
(
self
,
input_sequence
,
clone
=
False
):
cloned_tensors
=
[]
infini_list
=
[]
for
item
in
input_sequence
:
if
isinstance
(
item
,
torch
.
Tensor
):
if
clone
:
cloned_item
=
item
.
clone
().
detach
()
infini_item
=
infinicore_tensor_from_torch
(
cloned_item
)
cloned_tensors
.
append
(
cloned_item
)
else
:
infini_item
=
infinicore_tensor_from_torch
(
item
)
else
:
infini_item
=
item
infini_list
.
append
(
infini_item
)
return
infini_list
,
cloned_tensors
def
prepare_infinicore_inputs_and_kwargs
(
self
,
inputs
,
kwargs
,
comparison_target
):
cloned_tensors
=
[]
infini_inputs
=
[]
# Prepare infinicore inputs - only clone if needed for comparison
for
i
,
inp
in
enumerate
(
inputs
):
if
isinstance
(
inp
,
torch
.
Tensor
):
# Clone only if this input will be used for comparison
if
comparison_target
==
i
:
cloned_inp
=
inp
.
clone
().
detach
()
infini_tensor
=
infinicore_tensor_from_torch
(
cloned_inp
)
cloned_tensors
.
append
(
cloned_inp
)
else
:
# For non-comparison inputs, we can use the original (but still need to convert)
infini_tensor
=
infinicore_tensor_from_torch
(
inp
)
infini_inputs
.
append
(
infini_tensor
)
elif
isinstance
(
inp
,
(
tuple
,
list
)):
infini_list
,
cloned_list
=
self
.
prepare_infinicore_list
(
inp
,
comparison_target
==
i
)
infini_inputs
.
append
(
infini_list
)
cloned_tensors
.
append
(
cloned_list
)
else
:
infini_inputs
.
append
(
inp
)
# Prepare infinicore kwargs
infini_kwargs
=
{}
for
key
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Check if this tensor is used for output comparison
if
key
==
"out"
and
comparison_target
==
"out"
:
cloned_value
=
value
.
clone
().
detach
()
infini_kwargs
[
key
]
=
infinicore_tensor_from_torch
(
cloned_value
)
cloned_tensors
.
append
(
cloned_value
)
elif
key
==
"out"
and
isinstance
(
comparison_target
,
int
):
infini_kwargs
[
key
]
=
infini_inputs
[
comparison_target
]
else
:
infini_kwargs
[
key
]
=
infinicore_tensor_from_torch
(
value
)
elif
isinstance
(
value
,
(
tuple
,
list
)):
infini_list
,
cloned_list
=
self
.
prepare_infinicore_list
(
value
,
key
==
"out"
)
cloned_tensors
.
append
(
cloned_list
)
infini_kwargs
[
key
]
=
infini_list
else
:
infini_kwargs
[
key
]
=
value
return
infini_inputs
,
infini_kwargs
,
cloned_tensors
def
run_test
(
self
,
device
,
test_case
,
config
):
def
run_test
(
self
,
device
,
test_case
,
config
):
"""
"""
Unified test execution flow
Unified test execution flow
...
@@ -478,66 +541,15 @@ class BaseOperatorTest(ABC):
...
@@ -478,66 +541,15 @@ class BaseOperatorTest(ABC):
)
)
# Prepare inputs and kwargs with actual tensors
# Prepare inputs and kwargs with actual tensors
inputs
,
kwargs
=
self
.
prepare_inputs_and_kwargs
(
test_case
,
device
)
inputs
,
kwargs
=
self
.
prepare_pytorch_inputs_and_kwargs
(
test_case
,
device
)
# For in-place operations on input tensors, we need to preserve the original state
original_inputs
=
[]
if
"out"
in
kwargs
and
isinstance
(
kwargs
[
"out"
],
torch
.
Tensor
):
# This is an in-place operation on an input tensor
# Store original values for comparison
for
inp
in
inputs
:
if
isinstance
(
inp
,
torch
.
Tensor
):
original_inputs
.
append
(
inp
.
clone
().
detach
())
else
:
original_inputs
.
append
(
inp
)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs
=
[]
torch_input_clones
=
[]
for
inp
in
inputs
:
if
isinstance
(
inp
,
torch
.
Tensor
):
cloned_inp
=
inp
.
clone
().
detach
()
torch_input_clones
.
append
(
cloned_inp
)
infini_tensor
=
infinicore_tensor_from_torch
(
cloned_inp
)
infini_inputs
.
append
(
infini_tensor
)
else
:
infini_inputs
.
append
(
inp
)
infini_kwargs
=
{}
for
key
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Clone tensor and convert to infinicore
cloned_value
=
value
.
clone
().
detach
()
torch_input_clones
.
append
(
cloned_value
)
infini_kwargs
[
key
]
=
infinicore_tensor_from_torch
(
cloned_value
)
else
:
# Pass through non-tensor values (scalars, strings, etc.)
infini_kwargs
[
key
]
=
value
# Determine comparison target
# Determine comparison target
comparison_target
=
test_case
.
comparison_target
comparison_target
=
test_case
.
comparison_target
# Handle infinicore output
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_kwargs
=
kwargs
.
copy
()
infini_inputs
,
infini_kwargs
,
cloned_tensors
=
(
if
"out"
in
infini_kwargs
:
self
.
prepare_infinicore_inputs_and_kwargs
(
inputs
,
kwargs
,
comparison_target
)
out_value
=
infini_kwargs
[
"out"
]
)
if
isinstance
(
out_value
,
torch
.
Tensor
):
# Single tensor output
if
isinstance
(
comparison_target
,
int
):
infini_kwargs
[
"out"
]
=
infini_inputs
[
comparison_target
]
else
:
cloned_out
=
out_value
.
clone
().
detach
()
torch_input_clones
.
append
(
cloned_out
)
infini_kwargs
[
"out"
]
=
infinicore_tensor_from_torch
(
cloned_out
)
elif
isinstance
(
out_value
,
(
tuple
,
list
)):
# Multiple tensor outputs
infini_outputs
=
[]
for
tensor
in
out_value
:
cloned_tensor
=
tensor
.
clone
().
detach
()
torch_input_clones
.
append
(
cloned_tensor
)
infini_outputs
.
append
(
infinicore_tensor_from_torch
(
cloned_tensor
))
infini_kwargs
[
"out"
]
=
tuple
(
infini_outputs
)
# Check operator implementations
# Check operator implementations
torch_implemented
=
True
torch_implemented
=
True
...
@@ -698,7 +710,7 @@ class BaseOperatorTest(ABC):
...
@@ -698,7 +710,7 @@ class BaseOperatorTest(ABC):
is_valid
=
compare_fn
(
infini_comparison
,
torch_comparison
)
is_valid
=
compare_fn
(
infini_comparison
,
torch_comparison
)
if
not
is_valid
:
if
not
is_valid
:
raise
AssertionError
(
f
"Result comparison failed
for
{
test_case
}
"
)
raise
AssertionError
(
f
"Result comparison failed
.
"
)
# ==========================================================================
# ==========================================================================
# UNIFIED BENCHMARKING LOGIC
# UNIFIED BENCHMARKING LOGIC
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment