# Target GPU architecture, substituted into the compiler's architecture flag.
# Override with `make GPU_ARCH=<arch>`.
GPU_ARCH ?= gfx936

# GNU Make predefines CXX (normally as g++), so `CXX ?= hipcc` would never
# take effect. Default to hipcc only when CXX has not been set by the user,
# so a value from the environment or the command line still wins.
ifneq (,$(filter default undefined,$(origin CXX)))
CXX := hipcc
endif

# Base compile flags shared by every build rule. Override with
# `make CXX_FLAGS=...`.
CXX_FLAGS ?= -std=c++17 -O3

# Executable built by the default `all` goal, and its single source file.
TARGET  := gemv_bench
SRC     := main.cpp

# Shared library built by the `lib` / `lib.so` goals, and its single source file.
LIB     := libgemv_bf16.so
LIB_SRC := gemv_export.cpp

# Headers listed as prerequisites of both artifacts, so editing any of them
# triggers a rebuild.
DEP := \
    gemv_bf16.h \
    gemv_utils.h \
    hip_compat.h

# Non-empty ("hipcc") when the compiler command contains the substring
# "hipcc"; empty otherwise.
IS_HIPCC := $(findstring hipcc,$(CXX))

# Pick the architecture flags to match the compiler.
ifeq ($(strip $(IS_HIPCC)),)
# Any compiler other than hipcc is treated as NVCC; `-x cu` is passed so the
# .cpp sources are compiled as CUDA.
ARCH_FLAGS := -arch=$(GPU_ARCH) -x cu
else
# HIPCC
ARCH_FLAGS := --offload-arch=$(GPU_ARCH)
endif

# Command-style goals that do not name files on disk.
.PHONY: all lib lib.so clean

# Default goal: build the executable.
all: $(TARGET)

# `lib` and `lib.so` are both aliases for building the shared library.
lib lib.so: $(LIB)

# Build the shared library. Only the first prerequisite ($<, the source file)
# is passed to the compiler; the headers in $(DEP) are listed solely so that
# editing them triggers a rebuild.
# nvcc does not accept -fPIC directly, so on the non-hipcc path the flag is
# forwarded to the host compiler with -Xcompiler.
$(LIB): $(LIB_SRC) $(DEP)
	$(CXX) $(CXX_FLAGS) $(ARCH_FLAGS) -shared $(if $(IS_HIPCC),-fPIC,-Xcompiler -fPIC) -o $@ $<

# Build the executable. Only the first prerequisite ($<, the source file) is
# compiled; the headers in $(DEP) are listed solely so that editing them
# triggers a rebuild.
$(TARGET): $(SRC) $(DEP)
	$(CXX) $(CXX_FLAGS) $(ARCH_FLAGS) \
		-o $@ $<

# Remove the build outputs and any stray object files.
# `lib` and `lib.so` are phony goal names, not files this Makefile creates, so
# they are deliberately not removed: deleting them could clobber unrelated
# user files, and `rm -f lib` fails if `lib` happens to be a directory.
clean:
	$(RM) $(TARGET) $(LIB) *.o
