# Preconditions for the int8 simulation grouped GEMM path (out_dtype is already a torch.dtype):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports the same token padding as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "out_dtype must be bfloat16 or float32 for int8 simulation."

# Same preconditions where out_dtype is a TE dtype and is mapped to a torch.dtype via TE_DType_To_Torch:
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports the same token padding as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "out_dtype must be bfloat16 or float32 for int8 simulation."
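# The assertions above only gate the int8-simulation path; as a minimal sketch of what such a
# path could look like once the preconditions hold, the helper below fake-quantizes inputs and
# weights to int8 per tensor, runs the groups as one batched GEMM (valid only because all
# m_splits are equal), and casts to the requested output dtype. The function name
# simulate_int8_grouped_gemm, the argument layout, and the per-tensor symmetric quantization
# scheme are illustrative assumptions, not the library's actual implementation or API.

import torch

def simulate_int8_grouped_gemm(inp, weights, m_splits, out_dtype=torch.bfloat16):
    """Illustrative int8 simulation of a grouped GEMM (sketch, not the library code).

    Assumes the preconditions asserted above: uniform token padding (all m_splits equal),
    no GELU fusion, no bias, and a bfloat16/float32 output dtype.
    """
    assert len(set(m_splits)) == 1, "uniform token padding required"
    assert out_dtype in (torch.bfloat16, torch.float32)

    m = m_splits[0]
    # Reshape [sum(m_splits), k] -> [num_groups, m, k]; only valid because every group has the same m.
    x = inp.view(len(m_splits), m, -1)
    w = torch.stack(weights)  # assumed per-group weights of shape [n, k] -> [num_groups, n, k]

    def fake_quant_int8(t):
        # Per-tensor symmetric fake quantization: round to an int8 grid, keep compute in float.
        scale = t.abs().amax().clamp(min=1e-8) / 127.0
        return (t / scale).round().clamp(-128, 127) * scale

    xq = fake_quant_int8(x.float())
    wq = fake_quant_int8(w.float())

    # One batched GEMM over all groups: [g, m, k] @ [g, k, n] -> [g, m, n].
    out = torch.bmm(xq, wq.transpose(1, 2))
    return out.to(out_dtype).view(-1, out.shape[-1])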