Commit 39c133c4 authored by PanZezhong's avatar PanZezhong
Browse files

issue/48 support all int type pos_id, add rope to CI

parent 025894f3
......@@ -12,11 +12,12 @@ os.chdir(PROJECT_DIR)
def run_tests(args):
failed = []
for test in [
"causal_softmax.py",
"gemm.py",
"random_sample.py",
"rms_norm.py",
"causal_softmax.py",
"rope.py",
"swiglu.py",
"random_sample.py",
]:
result = subprocess.run(
f"python {test} {args}", text=True, encoding="utf-8", shell=True
......
......@@ -86,6 +86,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
......
......@@ -77,6 +77,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
......
......@@ -79,7 +79,7 @@ public:
CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(pos_type, INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64);
CHECK_DTYPE_ANY_INT(pos_type);
CHECK_OR_RETURN(y_desc->ndim() == 3
&& x_desc->ndim() == 3
......
......@@ -41,6 +41,11 @@
return INFINI_STATUS_BAD_TENSOR_DTYPE); \
} while (0)
#define CHECK_DTYPE_ANY_INT(DT) \
CHECK_DTYPE(DT, \
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64, \
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
#define CHECK_SAME_VEC(ERR, FIRST, ...) \
do { \
for (const auto &shape___ : {__VA_ARGS__}) { \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment