Commit 00cf8c53 authored by pengcheng888's avatar pengcheng888
Browse files

issue/637 - 为from_list函数添加bf16数据类型的支持

parent 98fff64d
import ml_dtypes
import numpy as np
import torch
......@@ -56,6 +57,8 @@ def numpy_to_infinicore_dtype(numpy_dtype):
return infinicore.float64
elif numpy_dtype == np.float16:
return infinicore.float16
elif numpy_dtype == ml_dtypes.bfloat16:
return infinicore.bfloat16
elif numpy_dtype == np.int8:
return infinicore.int8
elif numpy_dtype == np.int16:
......@@ -82,6 +85,8 @@ def infinicore_to_numpy_dtype(infini_dtype):
return np.int8
elif infini_dtype == infinicore.int16:
return np.int16
elif infini_dtype == infinicore.bfloat16:
return ml_dtypes.bfloat16
elif infini_dtype == infinicore.int32:
return np.int32
elif infini_dtype == infinicore.int64:
......
......@@ -144,8 +144,33 @@ def test4_to():
print(" 简单的测试用例,通过!!")
def test5_bf16():
    """Exercise ``from_list`` with the bfloat16 dtype.

    Builds the same bf16 tensor via torch and via infinicore, copies the
    infinicore result back into torch-owned memory through ``from_blob``,
    and prints the element-wise difference for manual inspection.
    """
    values = [1.1, 2.2, 3.3]

    reference = torch.tensor(values, dtype=torch.bfloat16)
    print("torch的bf16的数据\n", reference.dtype, reference)

    candidate = infinicore.from_list(values, dtype=infinicore.bfloat16)
    print("\n\ninfini的bf16的数据类型\n", candidate.dtype)
    print("----------------------------------------")

    # Wrap a zero-filled torch buffer so infinicore can write into it in place.
    out_buffer = torch.zeros(candidate.shape, dtype=torch.bfloat16)
    out_view = infinicore.from_blob(
        out_buffer.data_ptr(),
        candidate.shape,
        dtype=infinicore.bfloat16,
        device=infinicore.device("cpu", 0),
    )
    out_view.copy_(candidate)

    print("误差:", reference - out_buffer)
if __name__ == "__main__":
    # Run every test case in order when executed as a script.
    for case in (test, test2, test3, test4_to, test5_bf16):
        case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment