Commit a31b131f authored by yan.yan
Browse files

fix wrong int8 dtype

parent 34e97911
...@@ -39,17 +39,17 @@ class AlgoHint(Enum): ...@@ -39,17 +39,17 @@ class AlgoHint(Enum):
# TODO two step build: build gemm kernels first, then bind for every python # TODO two step build: build gemm kernels first, then bind for every python
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None), "", 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, (64, 32, 32), ["s8,s8,s8,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None), kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
...@@ -164,7 +164,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -164,7 +164,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
...@@ -182,9 +182,9 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -182,9 +182,9 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
] ]
......
...@@ -4,9 +4,18 @@ class CompileInfo: ...@@ -4,9 +4,18 @@ class CompileInfo:
@staticmethod @staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ... def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod @staticmethod
def get_compiled_gemm_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod
def arch_is_compiled(arch: Tuple[int, int]) -> bool: def arch_is_compiled(arch: Tuple[int, int]) -> bool:
""" """
Args: Args:
arch: arch:
""" """
... ...
@staticmethod
def arch_is_compiled_gemm(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
...@@ -57,8 +57,14 @@ def waymo_data_large(batch_size=1): ...@@ -57,8 +57,14 @@ def waymo_data_large(batch_size=1):
pc4[:, 1] += 3 pc4[:, 1] += 3
pc5 = pc.copy() pc5 = pc.copy()
pc5[:, 1] += 4 pc5[:, 1] += 4
pc6 = pc.copy()
pc = np.concatenate([pc, pc2, pc3, pc4, pc5]) pc6[:, 1] += 5
pc7 = pc.copy()
pc7[:, 1] += 6
pc8 = pc.copy()
pc8[:, 1] += 7
pc = np.concatenate([pc, pc2, pc3, pc4, pc5, pc6, pc7, pc8])
print(pc.shape) print(pc.shape)
voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc)) voxels_tv, indices_tv, _ = gen.point_to_voxel(tv.from_numpy(pc))
voxels = voxels_tv.numpy().reshape(-1, 3) voxels = voxels_tv.numpy().reshape(-1, 3)
...@@ -402,7 +408,7 @@ def main(): ...@@ -402,7 +408,7 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = NetSm(spatial_shape, algo).to(device).eval().to(dtype)# .train() net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict()) # net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
...@@ -427,12 +433,17 @@ def main(): ...@@ -427,12 +433,17 @@ def main():
items = list(timer.get_all_pair_time().items()) items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0]) items.sort(key=lambda x: x[0])
print("SUM TIME:", sum([x[1] for x in items])) print("SUM TIME:", sum([x[1] for x in items]))
print(json.dumps(dict(items), indent=2)) # print(json.dumps(dict(items), indent=2))
inds_sum = 0 inds_sum = 0
gemm_sum = 0
for k, v in items: for k, v in items:
if "gen_pairs" in k: if "gen_pairs" in k:
inds_sum += v inds_sum += v
print("SUM GEN INDS:", inds_sum) for k, v in items:
if "gemm" in k:
gemm_sum += v
print("SUM GEN INDS:", inds_sum, "GEMM:", gemm_sum)
# state = net.state_dict() # state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training") # state.pop("net.2.max_num_voxels_during_training")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment