Unverified Commit 674120e1 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/620 处理在python代码中tensor从gpu to到cpu上出错的问题 (Fix the error when a tensor is moved from GPU to CPU via `.to()` in Python code) (#621)


Co-authored-by: default avatarpengcheng888 <pengcheng@example.com>
parents ed012302 ed66e9a0
......@@ -173,8 +173,6 @@ std::shared_ptr<TensorImpl> TensorImpl::empty(const Shape &shape,
auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape, dtype));
t->data_.offset = 0;
context::setDevice(device);
if (device == Device::Type::CPU) {
if (pin_memory) {
if (context::getDevice() == Device::Type::CPU) {
......@@ -187,6 +185,7 @@ std::shared_ptr<TensorImpl> TensorImpl::empty(const Shape &shape,
t->data_.memory = context::allocateHostMemory(t->numel() * dsize(dtype));
}
} else {
context::setDevice(device);
t->data_.memory = context::allocateMemory(t->numel() * dsize(dtype));
}
......@@ -203,8 +202,6 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_empty(
auto impl = std::shared_ptr<TensorImpl>(new TensorImpl(shape, strides, dtype));
impl->data_.offset = 0;
context::setDevice(device);
size_t max_offset = 0;
for (size_t i = 0; i < shape.size(); ++i) {
......@@ -228,6 +225,7 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_empty(
impl->data_.memory = context::allocateHostMemory(required_bytes);
}
} else {
context::setDevice(device);
impl->data_.memory = context::allocateMemory(required_bytes);
}
......
......@@ -107,7 +107,45 @@ def test3():
print("abs error: ", torch.abs(ans_torch_ref - torch_ans_result).max())
def test4_to():
    """Regression test for issue #620: tensor `.to()` transfers between devices.

    Exercises `.to()` in both directions (CPU -> GPU and GPU -> CPU) and the
    same-device case (e.g. GPU -> same GPU), which previously failed when a
    tensor was moved from GPU back to CPU from Python code.

    Prints diagnostics via `.debug()`; raises/crashes on failure rather than
    asserting values.
    """
    # Case 1: tensor originating on the CPU (imported from a torch tensor).
    x = torch.rand((2, 3), dtype=torch.float32, device="cpu")
    x_infini = infinicore.from_torch(x.clone())
    print(" ---------------> test 1")
    x_infini.debug()

    # CPU -> GPU, then GPU -> same GPU (same-device move must be safe).
    x_gpu = x_infini.to(infinicore.device("cuda", 0))
    x_gpu = x_gpu.to(infinicore.device("cuda", 0))
    x_gpu.debug()

    # CPU -> CPU, twice (same-device move on the host side).
    x_cpu = x_infini.to(infinicore.device("cpu", 0))
    x_cpu = x_cpu.to(infinicore.device("cpu", 0))
    x_cpu.debug()

    # Case 2: tensor created directly on the GPU.
    x = infinicore.empty(
        (2, 3), dtype=infinicore.float32, device=infinicore.device("cuda", 0)
    )
    x.debug()
    x.to(infinicore.device("cuda", 0))

    # GPU -> CPU: the path that used to fail (issue #620).
    x_cpu = x.to(infinicore.device("cpu", 0))
    x_cpu.debug()

    # Round-trip: back to the GPU again.
    x_cpu_gpu = x_cpu.to(infinicore.device("cuda", 0))
    x_cpu_gpu.debug()
    x_gpu = x.to(infinicore.device("cuda", 0))
    x_gpu.debug()
    print(" 简单的测试用例,通过!!")
if __name__ == "__main__":
    # Run every test case in sequence; each prints its own diagnostics
    # and raises on failure.
    for case in (test, test2, test3, test4_to):
        case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment