Commit 7a8f2bca authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: modify swiglu test to correctly handle broadcast scenarios, add two...

issue/127: modify swiglu test to correctly handle broadcast scenarios, add two broadcast testcases, correct elementwise cpu mix-precision implementation
parent 6292da00
......@@ -84,7 +84,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
};
out[out_idx] = utils::cast<Tout>(Op{}(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
......
......@@ -25,8 +25,10 @@ _TEST_CASES_ = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4), (0, 1), None, None),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
......@@ -76,6 +78,38 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
"""
rearrange the tensors if needed and apply the inplace config.
if inplace is true and the output (i.e., c) is placed to the broadcasted input,
the inplace config is ignored and out-of-place is used
"""
original_c_strides = c_strides if c_strides else c.stride()
def _rearrange(tensor, strides):
if strides and 0 in strides:
tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides)
return tensor
else:
return rearrange_if_needed(tensor, strides)
a, b, c = [
_rearrange(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_strides])
]
c = (
c
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
# if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides
if 0 in c.stride():
c.set_(c.untyped_storage(), 0, c.shape, original_c_strides)
return a, b, c
def test(
......@@ -98,18 +132,10 @@ def test(
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = swiglu(a, b)
a, b, c = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
c = (
c
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment