import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import time

# Check whether JAX can see and use the GPU
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())
print("Current device:", jax.devices()[0])
print("Device kind:", jax.devices()[0].device_kind)


# 1. Basic GPU matrix operations
def basic_gpu_operations():
    print("\n=== Basic GPU matrix operations ===")

    # Create large matrices (these are allocated on the GPU automatically)
    key = random.PRNGKey(42)
    size = 5000  # large matrices

    # Generate two independent random matrices (split the key so A and B differ)
    key_a, key_b = random.split(key)
    A = random.normal(key_a, (size, size))
    B = random.normal(key_b, (size, size))

    # GPU matrix multiplication
    start = time.time()
    C = jnp.dot(A, B)  # runs on the GPU automatically
    C.block_until_ready()  # JAX dispatches asynchronously; wait before stopping the timer
    compute_time = time.time() - start

    print(f"Matrix size: {size}x{size}")
    print(f"GPU matmul time: {compute_time:.4f} s")

    # Inspect the result
    print(f"Shape of C: {C.shape}")
    print(f"Mean of C: {jnp.mean(C):.6f}")
    print(f"Std of C: {jnp.std(C):.6f}")

    return C


# 2. Speeding things up with JIT compilation
def jit_acceleration():
    print("\n=== JIT compilation example ===")

    # Define a moderately complex function
    def complex_function(x):
        return jnp.sum(jnp.sin(x) * jnp.exp(-x**2) + jnp.log1p(jnp.abs(x)))

    # Create input data
    key = random.PRNGKey(0)
    x = random.normal(key, (10000, 1000))

    # Without JIT
    start = time.time()
    result_no_jit = complex_function(x)
    result_no_jit.block_until_ready()
    time_no_jit = time.time() - start

    # With JIT compilation
    jitted_function = jit(complex_function)

    # The first call pays the compilation cost
    start = time.time()
    result_jit = jitted_function(x)
    result_jit.block_until_ready()
    time_first_call = time.time() - start

    # The second call reuses the compiled function
    start = time.time()
    result_jit = jitted_function(x)
    result_jit.block_until_ready()
    time_second_call = time.time() - start

    print(f"Without JIT: {time_no_jit:.4f} s")
    print(f"JIT, first call (incl. compilation): {time_first_call:.4f} s")
    print(f"JIT, subsequent call: {time_second_call:.4f} s")
    print(f"Speedup: {time_no_jit/time_second_call:.2f}x")

    return result_jit


# 3. Automatic differentiation on the GPU
def autodiff_gpu():
    print("\n=== Automatic differentiation on the GPU ===")

    # Define the loss function
    def loss(params, x, y):
        w, b = params
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y)**2)

    # Generate synthetic data
    key = random.PRNGKey(1)
    key1, key2, key3 = random.split(key, 3)

    n_samples = 10000
    n_features = 100

    X = random.normal(key1, (n_samples, n_features))
    true_w = random.normal(key2, (n_features, 1))
    true_b = random.normal(key3, (1,))
    Y = jnp.dot(X, true_w) + true_b + 0.1 * random.normal(key1, (n_samples, 1))

    # Initialize parameters
    params = (random.normal(key1, (n_features, 1)), jnp.zeros((1,)))

    # JIT-compiled gradient function
    grad_loss = jit(grad(loss, argnums=0))

    # Compute gradients on the GPU
    start = time.time()
    gradients = grad_loss(params, X, Y)
    gradients[0].block_until_ready()  # wait for the async computation before timing
    grad_time = time.time() - start

    print(f"Data shapes: X={X.shape}, Y={Y.shape}")
    print(f"GPU gradient computation time: {grad_time:.4f} s")
    print(f"Weight gradient shape: {gradients[0].shape}")
    print(f"Bias gradient shape: {gradients[1].shape}")

    return gradients


# 4. Batch processing
def batch_processing():
    print("\n=== GPU batch processing ===")

    # A function that processes a single sample
    def process_sample(x):
        # Some arbitrary transformations
        return jnp.sum(jnp.tanh(x) * jnp.cos(x))

    # Vectorize it with vmap
    batched_process = vmap(process_sample)

    # Create a batch of data
    key = random.PRNGKey(3)
    batch_size = 10000
    sample_size = 1000
    batch_data = random.normal(key, (batch_size, sample_size))

    # Process samples one at a time (slow)
    def sequential_process(data):
        results = []
        for i in range(data.shape[0]):
            results.append(process_sample(data[i]))
        return jnp.array(results)

    # Timing: sequential
    start = time.time()
    sequential_results = sequential_process(batch_data)
    sequential_results.block_until_ready()
    sequential_time = time.time() - start

    # Timing: batched with vmap (runs in parallel on the GPU)
    start = time.time()
    batched_results = batched_process(batch_data)
    batched_results.block_until_ready()
    batched_time = time.time() - start

    print(f"Batch size: {batch_size}, sample size: {sample_size}")
    print(f"Sequential processing time: {sequential_time:.4f} s")
    print(f"Batched processing time: {batched_time:.4f} s")
    print(f"Speedup: {sequential_time/batched_time:.2f}x")

    # Check that both approaches agree
    error = jnp.max(jnp.abs(sequential_results - batched_results))
    print(f"Maximum difference between results: {error:.10f}")

    return batched_results
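
# Supplementary sketch: explicit device placement. The examples above rely on
# JAX's default placement; the helper below (an illustrative addition, not
# referenced by main) shows how an array can be moved to a specific device with
# jax.device_put and how to check where it lives. It assumes the first entry of
# jax.devices() is the accelerator of interest (it falls back to CPU otherwise).
def explicit_device_placement_example():
    print("\n=== Explicit device placement (supplementary sketch) ===")

    # Pick the first visible device (the GPU when the CUDA build is installed)
    device = jax.devices()[0]

    # Create an array and copy it onto the chosen device explicitly
    x = jnp.arange(16.0).reshape(4, 4)
    x_on_device = jax.device_put(x, device)

    # jax.Array.devices() reports which device(s) hold the array's buffers
    print(f"Requested device: {device}")
    print(f"Array now lives on: {x_on_device.devices()}")

    # Computations follow their inputs, so this matmul runs on the same device
    y = jnp.dot(x_on_device, x_on_device.T)
    print(f"Result shape: {y.shape}")

    return y
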
# 5. Advanced example: simple neural network layers
def neural_network_example():
    print("\n=== Neural network layers on the GPU ===")

    # A single dense layer
    def dense_layer(params, x):
        w, b = params
        return jnp.dot(x, w) + b

    # ReLU activation
    def relu(x):
        return jnp.maximum(0, x)

    # A two-layer neural network
    def two_layer_net(params, x):
        w1, b1, w2, b2 = params
        hidden = relu(jnp.dot(x, w1) + b1)
        output = jnp.dot(hidden, w2) + b2
        return output

    # Initialize parameters
    key = random.PRNGKey(4)
    key1, key2, key3, key4 = random.split(key, 4)

    input_size = 784   # MNIST image size
    hidden_size = 128
    output_size = 10   # 10 classes

    # Initialize the weights (He initialization)
    w1 = random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)

    params = (w1, b1, w2, b2)

    # Create a batch of inputs
    batch_size = 1024
    x_batch = random.normal(key3, (batch_size, input_size))

    # JIT-compile the network
    jitted_net = jit(two_layer_net)

    # Forward pass
    start = time.time()
    outputs = jitted_net(params, x_batch)
    outputs.block_until_ready()  # wait for the async computation before timing
    inference_time = time.time() - start

    print(f"Network architecture: {input_size} -> {hidden_size} -> {output_size}")
    print(f"Batch size: {batch_size}")
    print(f"GPU inference time: {inference_time:.6f} s")
    print(f"Output shape: {outputs.shape}")
    print(f"Per-class mean over the batch: {jnp.mean(outputs, axis=0)}")

    return outputs


def main():
    print("="*50)
    print("JAX GPU Demo - starting")
    print("="*50)

    # Run all examples
    try:
        # Example 1: basic GPU operations
        result1 = basic_gpu_operations()

        # Example 2: JIT acceleration
        result2 = jit_acceleration()

        # Example 3: automatic differentiation
        result3 = autodiff_gpu()

        # Example 4: batch processing
        result4 = batch_processing()

        # Example 5: neural network
        result5 = neural_network_example()

        print("\n" + "="*50)
        print("All examples completed!")
        print("="*50)

        # Verify that the GPU is actually being used
        print("\nChecking GPU usage:")
        with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
            # Run an operation so it shows up in the trace
            test_matrix = jnp.ones((1000, 1000))
            test_result = jnp.dot(test_matrix, test_matrix.T)
            print(f"Trace test finished, result shape: {test_result.shape}")

    except Exception as e:
        print(f"Error during execution: {e}")
        print("\nNote: make sure the GPU build of JAX is installed correctly")
        print("Install command: pip install --upgrade 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")


if __name__ == "__main__":
    main()
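

# Supplementary sketch: one way the gradients from example 3 could drive an
# actual parameter update. This is a minimal gradient-descent step assuming the
# same (w, b) parameter tuple and mean-squared-error loss as autodiff_gpu();
# the name sgd_step and the learning rate are illustrative choices.
@jit
def sgd_step(params, x, y):
    learning_rate = 0.01  # illustrative value

    def loss(p, x, y):
        w, b = p
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y)**2)

    # Gradient of the loss w.r.t. the parameters, followed by a plain SGD update
    grads = grad(loss)(params, x, y)
    return tuple(p - learning_rate * g for p, g in zip(params, grads))

# Example usage (with X, Y and params constructed as in autodiff_gpu):
#     for _ in range(100):
#         params = sgd_step(params, X, Y)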