Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
"""FFI APIs for tilelang"""
import tvm._ffi
import tvm.ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm._ffi._init_api("tl", __name__) # pylint: disable=protected-access
tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access
......@@ -3,7 +3,7 @@ from typing import List, Optional, Set, Union
from typing_extensions import Literal
from tvm import ir, tir, DataType
from tvm._ffi import get_global_func
from tvm.ffi import get_global_func
from tvm.target.target import Target
from tvm.tir import Schedule, IterVar
from tvm.tir.schedule import BlockRV
......
......@@ -68,15 +68,15 @@ ada_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("e5m2_float8", "float32"),
("e4m3_float8", "float32"),
("float8_e5m2", "float32"),
("float8_e4m3", "float32"),
]
hopper_tensorcore_supported = ada_tensorcore_supported
# TODO(lei): we should consider the dtype of the input a and b
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports e4m3_float8 * e5m2_float8
# As the tensor core may support float8_e4m3 * float8_e5m2
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_volta_arch(arch):
......
......@@ -695,14 +695,14 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support bfloat16, float16, int8, e4m3_float8, e5m2_float8"
"float8_e4m3",
"float8_e5m2",
], "Only support bfloat16, float16, int8, float8_e4m3, float8_e5m2"
# TODO(lei): actually should analyze based on bits instead of dtype
if dtype in ["bfloat16", "float16"]:
ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
elif dtype in ["int8", "float8_e4m3", "float8_e5m2"]:
# int8 mma only support 32x16 to 16x32 layout
if matrix_name == "A" and trans is False:
ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a
......@@ -760,12 +760,12 @@ def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
"float8_e4m3",
"float8_e5m2",
], "Only support float16, int8, float8_e4m3, float8_e5m2"
if dtype in ["bfloat16", "float16"]:
stage3_layout = shared_32x8_to_mma_32x8_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
elif dtype in ["int8", "float8_e4m3", "float8_e5m2"]:
stage3_layout = shared_32x16_to_mma_32x16_layout
else:
raise ValueError("Unknown dtype ", dtype)
......
......@@ -24,7 +24,7 @@ import subprocess
import sys
from typing import Dict
from tvm._ffi.base import py_str
from tvm.base import py_str
from tvm.contrib import tar as _tar
from tvm.contrib import utils as _utils
......
......@@ -37,10 +37,10 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
import torch
float8_dtype_map = {
torch.float8_e4m3fn: "e4m3_float8",
torch.float8_e4m3fn: "float8_e4m3",
torch.float8_e4m3fnuz: "float8_e4m3fnuz",
torch.float8_e5m2: "e5m2_float8",
torch.float8_e5m2fnuz: "e5m2_float8",
torch.float8_e5m2: "float8_e5m2",
torch.float8_e5m2fnuz: "float8_e5m2",
}
def adapt_tensor(arg):
......
......@@ -9,10 +9,10 @@ from __future__ import absolute_import as _abs
import subprocess
import tvm._ffi
import tvm.ffi
from tvm.contrib import utils
from tvm._ffi.base import py_str
from tvm.base import py_str
from tvm.contrib.rocm import get_rocm_arch, find_rocm_path
......@@ -96,7 +96,7 @@ def compile_hip(code,
return data
@tvm._ffi.register_func("tilelang_callback_hip_compile", override=True)
@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
"""use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco")
......
......@@ -8,10 +8,10 @@ import subprocess
import warnings
from ..env import CUDA_HOME
import tvm._ffi
import tvm.ffi
from tvm.target import Target
from tvm._ffi.base import py_str
from tvm.base import py_str
from tvm.contrib import utils
......@@ -181,14 +181,14 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file")
@tvm._ffi.register_func("tilelang_callback_cuda_compile", override=True)
@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
@tvm._ffi.register_func("tilelang_callback_libdevice_path", override=True)
@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch):
"""Utility function to find libdevice
......@@ -253,7 +253,7 @@ def callback_libdevice_path(arch):
return ""
@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
......@@ -391,7 +391,7 @@ def have_cudagraph():
return False
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
......@@ -404,7 +404,7 @@ def have_bf16(compute_version):
return major >= 8
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not
......@@ -421,7 +421,7 @@ def have_fp8(compute_version):
return any(conditions)
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target):
"""Whether TMA support is provided in the specified compute capability or not
......
......@@ -21,8 +21,8 @@ import subprocess
import os
from os.path import join, exists
import tvm._ffi
from tvm._ffi.base import py_str
import tvm.ffi
from tvm.base import py_str
import tvm.runtime
import tvm.target
......@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
raise RuntimeError(msg)
@tvm._ffi.register_func("tvm_callback_rocm_link", override=True)
@tvm.ffi.register_func("tvm_callback_rocm_link", override=True)
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
......@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
return cobj_bin
@tvm._ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
def callback_rocm_bitcode_path(rocdl_dir=None):
"""Utility function to find ROCm device library bitcodes
......@@ -226,7 +226,7 @@ def have_matrixcore(compute_version=None):
return False
@tvm._ffi.register_func("tvm_callback_rocm_get_arch", override=True)
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"):
"""Utility function to get the AMD GPU architecture
......
......@@ -29,9 +29,11 @@ def has_device_kernel_launch(attrs) -> bool:
def is_device_call_c_device(func: tir.PrimFunc):
    """Return True if *func* should be treated as a device call for the C target.

    A PrimFunc whose target kind is "c" counts as device code unless it uses
    the C-packed calling convention (C_PACKED_FUNC functions are host-side
    packed entry points, not device kernels).  Otherwise, fall back to
    checking whether the function carries a device-kernel-launch attribute.
    """
    attrs = func.attrs
    calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT)
    # C_PACKED_FUNC marks a host entry point; it must not be lowered as device code.
    is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC
    if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked:
        return True
    return has_device_kernel_launch(attrs)
......@@ -130,7 +132,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
    """Canonicalize *target_host*, choosing a default backend when unset.

    A caller-provided (truthy) *target_host* is returned unchanged.  When it
    is falsy (None or empty), default to "llvm" if the LLVM backend is enabled
    in this TVM build, otherwise fall back to the "c" backend.

    Note: *target* is currently unused; it is kept for signature compatibility
    with callers that pass both.
    """
    if not target_host:
        # Prefer LLVM when available; "c" is the portable host-codegen fallback.
        target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
    return target_host
......@@ -145,9 +147,9 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
host_mod = tir.transform.CombineContextCall()(host_mod)
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host)
elif target_host.kind.name == "c":
host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
host_mod = tvm.ffi.get_global_func("target.build.c")(host_mod, target_host)
else:
raise ValueError(f"Target host {target_host.kind.name} is not supported")
return host_mod
......@@ -159,9 +161,9 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
else:
raise ValueError(f"Target {target.kind.name} is not supported")
......@@ -173,17 +175,17 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip_without_compile")(
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(
device_mod, target)
elif target.kind.name == "c":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
elif target.kind.name == "llvm":
device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
else:
raise ValueError(f"Target {target.kind.name} is not supported")
......
......@@ -13,8 +13,7 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
# Warp specialized pass is recommended for Hopper or later architectures
if not is_cuda_target(target) or not have_tma(target):
if (not is_cuda_target(target)) or (not have_tma(target)):
return False
disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
return not disable_warp_specialized
......@@ -109,7 +108,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
# warp_specialized pass will pack the if stmt into the block
# so we need to lower the opaque block first
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
mod = tilelang.transform.RewriteWgmmaSync()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
......@@ -124,15 +123,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
......@@ -153,7 +151,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# the Legalization.
mod = tilelang.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
......@@ -178,9 +176,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
# Transform threadblock to persistent threadblock
mod = tilelang.transform.PersistThreadblock()(mod)
......
......@@ -25,8 +25,8 @@ class MatrixCoreIntrinEmitter(object):
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
"float8_e4m3fnuz": "e4m3fnuz",
}
......
......@@ -28,8 +28,8 @@ class TensorCoreIntrinEmitter(object):
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
......
......@@ -78,7 +78,7 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# Basic Tensor Core Matrix Multiply operation Unit
micro_size_x = micro_size_y = 16
micro_size_k = 16
if dtype in {"e4m3_float8", "e5m2_float8", "int8"}:
if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
micro_size_k = 32
return micro_size_x, micro_size_y, micro_size_k
......
......@@ -6,7 +6,7 @@ import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
from tvm.relax import TensorType
from tvm import tir
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
......
......@@ -7,7 +7,7 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tilelang.engine.param import KernelParam
from tvm import tir
from tvm.relay import TensorType
from tvm.relax import TensorType
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
......
......@@ -180,8 +180,8 @@ class TLCUDASourceWrapper(object):
"float32": "float",
"float16": "half_t",
"bfloat16": "bfloat16_t",
"e4m3_float8": "fp8_e4_t",
"e5m2_float8": "fp8_e5_t",
"float8_e4m3": "fp8_e4_t",
"float8_e5m2": "fp8_e5_t",
"float64": "double",
"int64": "int64_t",
"int32": "int",
......@@ -559,8 +559,8 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"float32": "ctypes.c_float",
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"e4m3_float8": "ctypes.c_uint8",
"e5m2_float8": "ctypes.c_uint8",
"float8_e4m3": "ctypes.c_uint8",
"float8_e5m2": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
"int32": "ctypes.c_int32",
......@@ -766,8 +766,8 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"float32": "float",
"float16": "half_t",
"bfloat16": "bfloat16_t",
"e4m3_float8": "fp8_e4_t",
"e5m2_float8": "fp8_e5_t",
"float8_e4m3": "fp8_e4_t",
"float8_e5m2": "fp8_e5_t",
"float8_e4m3fnuz": "fp8_e4_t",
"e4m3fnuz_float8": "fp8_e4_t",
"float64": "double",
......
......@@ -17,6 +17,6 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
import tvm._ffi
import tvm.ffi
tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
......@@ -1428,19 +1428,19 @@ float16x64 = func_gen(("Float16x64"))
float32x64 = func_gen(("Float32x64"))
float64x64 = func_gen(("Float64x64"))
e4m3_float8 = func_gen(("E4M3Float8"))
e4m3_float8x4 = func_gen(("E4M3Float8x4"))
e4m3_float8x8 = func_gen(("E4M3Float8x8"))
e4m3_float8x16 = func_gen(("E4M3Float8x16"))
e4m3_float8x32 = func_gen(("E4M3Float8x32"))
e4m3_float8x64 = func_gen(("E4M3Float8x64"))
e5m2_float8 = func_gen(("E5M2Float8"))
e5m2_float8x4 = func_gen(("E5M2Float8x4"))
e5m2_float8x8 = func_gen(("E5M2Float8x8"))
e5m2_float8x16 = func_gen(("E5M2Float8x16"))
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
e5m2_float8x64 = func_gen(("E5M2Float8x64"))
float8_e4m3 = func_gen(("E4M3Float8"))
float8_e4m3x4 = func_gen(("E4M3Float8x4"))
float8_e4m3x8 = func_gen(("E4M3Float8x8"))
float8_e4m3x16 = func_gen(("E4M3Float8x16"))
float8_e4m3x32 = func_gen(("E4M3Float8x32"))
float8_e4m3x64 = func_gen(("E4M3Float8x64"))
float8_e5m2 = func_gen(("E5M2Float8"))
float8_e5m2x4 = func_gen(("E5M2Float8x4"))
float8_e5m2x8 = func_gen(("E5M2Float8x8"))
float8_e5m2x16 = func_gen(("E5M2Float8x16"))
float8_e5m2x32 = func_gen(("E5M2Float8x32"))
float8_e5m2x64 = func_gen(("E5M2Float8x64"))
# pylint: enable=invalid-name
......@@ -1964,33 +1964,33 @@ __all__ = [
"uint16x64",
"uint32x64",
"uint64x64",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"float16",
"float32",
"float64",
"e4m3_float8x4",
"e5m2_float8x4",
"float8_e4m3x4",
"float8_e5m2x4",
"float16x4",
"float32x4",
"float64x4",
"e4m3_float8x8",
"e5m2_float8x8",
"float8_e4m3x8",
"float8_e5m2x8",
"float16x8",
"float32x8",
"float64x8",
"e4m3_float8x16",
"e5m2_float8x16",
"float8_e4m3x16",
"float8_e5m2x16",
"float16x16",
"float32x16",
"float64x16",
"e4m3_float8x32",
"e5m2_float8x32",
"float8_e4m3x32",
"float8_e5m2x32",
"float16x32",
"float32x32",
"float64x32",
"e4m3_float8x64",
"e5m2_float8x64",
"float8_e4m3x64",
"float8_e5m2x64",
"float16x64",
"float32x64",
"float64x64",
......
......@@ -2,6 +2,7 @@
from typing import Union, List, Optional
from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir
......@@ -109,6 +110,11 @@ def copy(
return data.shape
elif isinstance(data, tir.BufferRegion):
return [x.extent for x in data.region]
elif isinstance(data, tir.BufferLoad):
region = get_buffer_region_from_load(data)
if region is None:
return None
return [x.extent for x in region.region]
else:
return None
......@@ -126,6 +132,11 @@ def copy(
return buffer_to_tile_region(data, access_type)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
region = get_buffer_region_from_load(data)
if region is None:
return buffer_load_to_tile_region(data, access_type, extent)
return buffer_region_to_tile_region(region, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment