"...resnet50_tensorflow.git" did not exist on "d4112a1b78dd9ae4d556d4d22d48ce1ed2188f8d"
  • Lei Wang's avatar
    [Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444) · dda45126
    Lei Wang authored
    
    
    * [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method
    
    - Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions.
    - Added a new method, torch_dtype, that converts tvm.DataType back to torch.dtype for tensor creation (the idea is sketched after this list).
    - Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization.
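
    A minimal sketch of the conversion idea, assuming the lookup keys off the
    canonical TVM dtype string; the table and function below are illustrative,
    not the actual TileLang implementation:

        import torch

        # Hypothetical string-keyed mapping from TVM dtype names to torch dtypes.
        _TVM_TO_TORCH = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
            "float64": torch.float64,
            "int8": torch.int8,
            "int32": torch.int32,
            "int64": torch.int64,
            "uint8": torch.uint8,
            "bool": torch.bool,
        }

        def torch_dtype(tvm_dtype) -> torch.dtype:
            """Convert a tvm.DataType back to a torch.dtype for tensor creation."""
            key = str(tvm_dtype)  # e.g. str(tvm.DataType("float16")) == "float16"
            try:
                return _TVM_TO_TORCH[key]
            except KeyError:
                raise ValueError(f"no torch equivalent for TVM dtype {key!r}")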
    
    * [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types
    
    - Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively.
    - Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64); the packing arithmetic is sketched after this list.
    - Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions.
    - Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing.
    - Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch.
    - Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively.
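
    Illustrative packing arithmetic only (the actual packed vector structs live
    in the CUDA headers touched by this change); it shows why the listed FP4
    lane counts all pack into whole bytes:

        # Each FP4 element is 4 bits, each FP8 element 8 bits.
        SUPPORTED_FP4_LANES = (2, 4, 8, 16, 32, 64)

        def packed_bytes(bits_per_lane: int, lanes: int) -> int:
            """Storage size in bytes of a vector of `lanes` elements."""
            total_bits = bits_per_lane * lanes
            assert total_bits % 8 == 0, "vector must pack into whole bytes"
            return total_bits // 8

        # e.g. an 8-lane FP4 vector and a 4-lane FP8 vector both pack into 4 bytes
        assert all(packed_bytes(4, n) * 2 == n for n in SUPPORTED_FP4_LANES)
        assert packed_bytes(8, 4) == 4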
    
    * lint fix
    
    * Remove unnecessary logging statements from CUDA code generation and delete an obsolete matrix multiplication test file.
    
    * [Enhancement] Add support for FP4 and FP8 types in CUDA code generation
    
    - Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types.
    - Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch.
    - Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations.
    - Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations (a sketch of the mapping follows this list).
    - Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows.
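
    A minimal sketch of the extended mapping, assuming it keys off TVM dtype
    strings; the key names and the int8-backed float4 entry are assumptions that
    mirror the arg_binder relaxation above, not the library's actual table:

        import torch

        _EXTENDED_TORCH_MAP = {
            "float8_e4m3": torch.float8_e4m3fn,  # requires PyTorch >= 2.1
            "float8_e5m2": torch.float8_e5m2,
            # PyTorch has no native 4-bit float storage, so packed FP4 buffers
            # are exposed as int8 bytes (two FP4 values per byte).
            "float4_e2m1": torch.int8,
        }

        def extended_torch_dtype(tvm_dtype_str: str) -> torch.dtype:
            """Resolve FP8/FP4 TVM dtypes to a usable torch storage dtype."""
            return _EXTENDED_TORCH_MAP[tvm_dtype_str]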
    
    ---------
    Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>